From a455d3fa0285fe64b7a7bf8001b8908b0fb8ab7a Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 7 May 2024 23:23:40 +0000 Subject: [PATCH 01/36] Serialize GEMM kernel runs --- test/gemm/gemm_kernel_base_impl.hpp | 57 +++++++++++++++++------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/test/gemm/gemm_kernel_base_impl.hpp b/test/gemm/gemm_kernel_base_impl.hpp index 1e5cdf62..2584093c 100644 --- a/test/gemm/gemm_kernel_base_impl.hpp +++ b/test/gemm/gemm_kernel_base_impl.hpp @@ -644,35 +644,39 @@ namespace rocwmma this->mBeta); // beta }; + hipEvent_t startEvent, stopEvent; + CHECK_HIP_ERROR(hipEventCreate(&startEvent)); + CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); + // Cold runs for frequency warm-up for(uint32_t i = 0; i < mColdRuns; ++i) { rocwmmaKernel(); } - // Use the hot runs for timing - hipEvent_t startEvent, stopEvent; - CHECK_HIP_ERROR(hipEventCreate(&startEvent)); - CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); - CHECK_HIP_ERROR(hipEventRecord(startEvent)); + // Finish cold runs + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + + // Use the hot runs for timing. Ensure sequential execution. + mElapsedTimeMs = 0.0; for(uint32_t i = 0; i < mHotRuns; ++i) { + CHECK_HIP_ERROR(hipEventRecord(startEvent)); rocwmmaKernel(); + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + auto timeMs = 0.0f; + CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); + mElapsedTimeMs += timeMs; } - CHECK_HIP_ERROR(hipEventRecord(stopEvent)); - CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); - - auto timeMs = 0.0f; - CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); // Calculate efficiency auto& deviceInfo = DeviceInfo::instance(); auto devicePeakGFlopsPerSec = deviceInfo->peakGFlopsPerSec(); - - mElapsedTimeMs = float64_t(timeMs); - mTotalGFlops = calculateGFlops(mM, mN, mK); - mMeasuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, mElapsedTimeMs) + mTotalGFlops = calculateGFlops(mM, mN, mK); + mMeasuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, mElapsedTimeMs) * static_cast(mHotRuns); mEfficiency = round(mMeasuredTFlopsPerSec / devicePeakGFlopsPerSec * 100000.0); @@ -802,37 +806,42 @@ namespace rocwmma std::numeric_limits::signaling_NaN()); } + hipEvent_t startEvent, stopEvent; + CHECK_HIP_ERROR(hipEventCreate(&startEvent)); + CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); + // Cold runs for frequency warm-up for(uint32_t i = 0; i < mColdRuns; ++i) { refKernel(); } + // Finish cold runs + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + // Hot runs for timing - hipEvent_t startEvent, stopEvent; - CHECK_HIP_ERROR(hipEventCreate(&startEvent)); - CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); - CHECK_HIP_ERROR(hipEventRecord(startEvent)); + auto elapsedTimeMs = 0.0; for(uint32_t i = 0; i < mHotRuns; ++i) { + CHECK_HIP_ERROR(hipEventRecord(startEvent)); refKernel(); + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + auto timeMs = 0.0f; + CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); + elapsedTimeMs += timeMs; } - CHECK_HIP_ERROR(hipEventRecord(stopEvent)); - CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); - auto timeMs = 0.0f; - CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); CHECK_HIP_ERROR(hipEventDestroy(startEvent)); CHECK_HIP_ERROR(hipEventDestroy(stopEvent)); // Calculate reference efficiency if constexpr(mBenchRef) { - auto& deviceInfo = DeviceInfo::instance(); auto devicePeakGFlopsPerSec = deviceInfo->peakGFlopsPerSec(); - auto elapsedTimeMs = float64_t(timeMs); auto measuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, elapsedTimeMs) * static_cast(mHotRuns); From 5337ba8bdfa490cb73e07fb6186a6cdebe992020 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 27 Jun 2024 16:26:12 +0000 Subject: [PATCH 02/36] First working interleaved 128x128 macro kernel --- library/include/rocwmma/internal/layout.hpp | 160 ++++- .../include/rocwmma/internal/layout_impl.hpp | 648 ++++++++++++++++- .../include/rocwmma/internal/opaque_load.hpp | 25 +- samples/common.hpp | 70 ++ samples/perf_hgemm.cpp | 654 ++++++++++++++---- 5 files changed, 1398 insertions(+), 159 deletions(-) diff --git a/library/include/rocwmma/internal/layout.hpp b/library/include/rocwmma/internal/layout.hpp index c85c0c83..b7dfa5d2 100644 --- a/library/include/rocwmma/internal/layout.hpp +++ b/library/include/rocwmma/internal/layout.hpp @@ -79,6 +79,45 @@ namespace rocwmma uint32_t MaxVectorWidth> struct RowInlineVW; + /////////////////// Interleaved patterns ////////////////// + template // # of splits + struct ColInlineInt; + + template // # of splits + struct ColOrthoInt; + + template // # of splits + struct RowInlineInt; + + template // # of splits + struct RowOrthoInt; + + /////////////////// ////////////////////////////// ////////////////// + } // namespace MatrixLayout // Register layouts describe whether contiguous BlockDim elements are: @@ -90,7 +129,6 @@ namespace rocwmma struct Aos { }; - template struct Soa { @@ -213,8 +251,8 @@ namespace rocwmma MatrixLayout::ColOrthoVW>; using RegisterLayout = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + RegisterLayout::template Aos, + RegisterLayout::template Soa>; // Mapping using MappingUtil = MappingUtil; @@ -254,8 +292,118 @@ namespace rocwmma MatrixLayout::RowOrthoVW>; using RegisterLayout = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + RegisterLayout::template Aos, + RegisterLayout::template Soa>; + + // Mapping + using MappingUtil = MappingUtil; + using MatrixCoordT = typename MappingUtil::MatrixCoordT; + + // Sanity checks + // Must ensure that MaxVectorWidth fits inside the leading dimension + static_assert( + !(is_same_v && (MaxVectorWidth > BlockK)), + "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); + }; + + //////////////// Interleaved layouts ///////////// + + // Col is a layout profile that has the following properties: + // 1. Leading dimension is aligned with column elements of fragment data: + // - BlockDim is assumed the column size, or BlockM dimension. + // - Analogous to capturing columns of 'matrix A'. + // 2. Register layout is dynamic: + // - col_major data is stored in AOS register layout (non-MFMA friendly), and + // - row_major data is stored in SOA register layout (MFMA friendly). + // - Both data layouts cover the same geometric elements (transform friendly). + // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). + // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). + // 5. User must convert to SOA if using with MFMA. + template + struct ColInt + { + // Layouts + using DataLayout = DataLayout::template Array1d; + using MatrixLayout = conditional_t, + MatrixLayout::ColInlineInt, + MatrixLayout::ColOrthoInt>; + using RegisterLayout + = conditional_t, + RegisterLayout::template Aos, + RegisterLayout::template Soa>; + + // Mapping + using MappingUtil = MappingUtil; + using MatrixCoordT = typename MappingUtil::MatrixCoordT; + + // Sanity checks + // Must ensure that MaxVectorWidth fits inside the leading dimension + static_assert( + !(is_same_v && (MaxVectorWidth > BlockK)), + "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); + }; + + // Row is a layout profile that has the following properties: + // 1. Leading dimension is aligned with row elements of fragment data: + // - BlockDim is assumed the row size, or BlockN dimension. + // - Analogous to capturing rows of 'matrix B' or 'accumulator'. + // 2. Register layout is dynamic: + // - row_major data is stored in AOS register layout (non-MFMA friendly), and + // - col_major data is stored in SOA register layout (MFMA friendly). + // - Both data layouts cover the same geometric elements (transform friendly). + // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). + // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). + // 5. User must convert to SOA if using with MFMA. + template + struct RowInt + { + // Layouts + using DataLayout = DataLayout::template Array1d; + using MatrixLayout = conditional_t, + MatrixLayout::RowInlineInt, + MatrixLayout::RowOrthoInt>; + using RegisterLayout + = conditional_t, + RegisterLayout::template Aos, + RegisterLayout::template Soa>; // Mapping using MappingUtil = MappingUtil; @@ -267,7 +415,7 @@ namespace rocwmma !(is_same_v && (MaxVectorWidth > BlockK)), "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); }; - + } // namespace FragmentLayout /// diff --git a/library/include/rocwmma/internal/layout_impl.hpp b/library/include/rocwmma/internal/layout_impl.hpp index 2ef65dfa..cdd965a5 100644 --- a/library/include/rocwmma/internal/layout_impl.hpp +++ b/library/include/rocwmma/internal/layout_impl.hpp @@ -257,6 +257,8 @@ namespace rocwmma return make_coord2d(cumBlockDimOffsetX, cumVWOffsetY + cumBlockKOffsetY); } + + ROCWMMA_DEVICE static inline auto debug() {} }; /* Pattern that maps threads to matrix columns and assumes @@ -518,12 +520,544 @@ namespace rocwmma * (int32_t)Traits::BlockKStride_Y; int32_t cumBlockDimOffsetX = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::BlockDimStride_X; + * (int32_t)Traits::BlockKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + ROCWMMA_DEVICE static inline auto debug() {} + }; + + //////////////// Interleaved patterns //////////////////////////////////// + template // # of splits + struct ColInlineInt + { + using IOTraits = IOTraits; + struct Traits + { + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 0u, + BlockKStride_Y = 1u, + + VWStride_X = VectorWidth, + VWStride_Y = 0u, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = KPerThread / BlockKStride_Y, + VWSegs = DimPerThread / VWStride_X, + }; + + // Check VectorWidth validity + static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); + static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, + "DimPerThread not a multiple of VectorWidth"); + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowInlineInt; + + using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + + return make_vector((uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + ROCWMMA_DEVICE static inline auto debug() + { + if(threadIdx.x == 0 && threadIdx.y == 0) + { + printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + (uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); + + printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " + "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + (uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y, + (uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y, + (uint32_t)Traits::VWStride_X, + (uint32_t)Traits::VWStride_Y); + } + if(threadIdx.x <= 63 && threadIdx.y == 0) + { + printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", + threadIdx.x, + get<0>(baseOffset()), + get<1>(baseOffset())); + } + } + }; + template // # of splits + struct ColOrthoInt + { + using IOTraits = IOTraits; + struct Traits + { + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 1u, + BlockKStride_Y = 0u, + + VWStride_X = 0u, + VWStride_Y = VectorWidth, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = DimPerThread / BlockKStride_X, + VWSegs = KPerThread / VWStride_Y, + }; + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // Check VectorWidth validity + static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); + static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, + "KPerThread not a multiple of VectorWidth"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowOrthoInt; + + using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments + (uint32_t)Traits::BlockKSegs, // BlockK Segments + (uint32_t)Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); } + + ROCWMMA_DEVICE static inline auto debug() + { + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); + // } + } }; + template + struct RowInlineInt + { + // RowInlineVW is orthogonal to ColInlineVW, therefore we can use reversed coordinates + struct Traits + { + using OrthoLayout = ColInlineInt; + + using MatrixCoordT = Coord2d; + }; + + // Matrix coord offsets + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return swap(Traits::OrthoLayout::baseOffset()); + } + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return Traits::OrthoLayout::strideCounts(); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + auto t = Traits::OrthoLayout::strides(); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + } + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto debug() + { + Traits::OrthoLayout::debug(); + } + }; + + template + struct RowOrthoInt + { + // RowOrthoVW is orthogonal to ColOrthoVW, therefore we can use reversed coordinates + struct Traits + { + using OrthoLayout = ColOrthoInt; + + using MatrixCoordT = Coord2d; + }; + + // Matrix coord offsets + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return swap(Traits::OrthoLayout::baseOffset()); + } + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return Traits::OrthoLayout::strideCounts(); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + auto t = Traits::OrthoLayout::strides(); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto debug() {} + }; + + /////////////////////////// + template (t)), swap(get<1>(t)), swap(get<2>(t))); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT @@ -568,6 +1101,8 @@ namespace rocwmma { return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); } + + ROCWMMA_DEVICE static inline auto debug() {} }; template (t)), swap(get<1>(t)), swap(get<2>(t))); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT @@ -614,6 +1148,8 @@ namespace rocwmma { return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); } + + ROCWMMA_DEVICE static inline auto debug() {} }; } // namespace MatrixLayout @@ -738,6 +1274,102 @@ namespace rocwmma template ColInlineVW; }; + template + struct OrthogonalLayout> + { + using Type = MatrixLayout::template RowOrthoInt; + }; + + template + struct OrthogonalLayout> + { + using Type = MatrixLayout::template RowInlineInt; + }; + + template + struct OrthogonalLayout> + { + using Type = MatrixLayout::template ColOrthoInt; + }; + + template + struct OrthogonalLayout> + { + using Type = MatrixLayout::template ColInlineInt; + }; + // Register layouts template struct OrthogonalLayout> @@ -757,9 +1389,9 @@ namespace rocwmma // In general, assume that an orthogonal layout has been assigned template - struct is_orthogonal : public integral_constant< - bool, - is_same_v, RhsDataLayout>> + struct is_orthogonal + : public integral_constant, RhsDataLayout>> { }; diff --git a/library/include/rocwmma/internal/opaque_load.hpp b/library/include/rocwmma/internal/opaque_load.hpp index c14e1978..3fadf97b 100644 --- a/library/include/rocwmma/internal/opaque_load.hpp +++ b/library/include/rocwmma/internal/opaque_load.hpp @@ -61,7 +61,8 @@ namespace rocwmma typename DataT, class DataLayout, class MatrixLayout, - uint32_t VectorWidth> + uint32_t VectorWidth, + bool Debug = false> struct OpaqueLoad { using IOTraits = IOTraits; @@ -78,10 +79,7 @@ namespace rocwmma // Outer loop = index 0, // Inner loop = index N-1 - template + template ROCWMMA_DEVICE static inline auto unroll_right(Iterator& out, DataT const* dataPtr, uint32_t ldm, @@ -94,6 +92,14 @@ namespace rocwmma // Last depth layer will invoke the load if constexpr(Depth == (VecTraits>::size() - 1u)) { + if constexpr(Debug) + { + printf("Depth: %d, StrideCount: %d\n", Depth, get(strideCounts)); + printf("StrideX: %d, StrideY: %d\n", + get<0>(get(strides2d)), + get<1>(get(strides2d))); + printf("Executing!\n"); + } #pragma unroll for(int i = 0; i < strideCount; i++) { @@ -105,6 +111,14 @@ namespace rocwmma // Recurse to the next nested layer else { + if constexpr(Debug) + { + printf("Depth: %d, StrideCount: %d\n", Depth, get(strideCounts)); + printf("StrideX: %d, StrideY: %d\n", + get<0>(get(strides2d)), + get<1>(get(strides2d))); + printf("Recursing!\n"); + } #pragma unroll for(int i = 0; i < strideCount; i++) { @@ -117,6 +131,7 @@ namespace rocwmma ROCWMMA_DEVICE static void exec(typename Traits::OutputT& data, DataT const* dataPtr, uint32_t ldm) { + //MatrixLayout::debug(); // Arrange wave threads to starting matrix layout offsets. auto baseOffset2d = MatrixLayout::baseOffset(); auto it = makeVectorIterator(data).begin(); diff --git a/samples/common.hpp b/samples/common.hpp index 247759f9..2ef073c9 100644 --- a/samples/common.hpp +++ b/samples/common.hpp @@ -198,6 +198,15 @@ __host__ static inline void } } +template +__host__ static inline void fillVal(DataT* mat, uint32_t m, uint32_t n, DataT val = 1) +{ + for(int i = 0; i < m * n; ++i) + { + mat[i] = val; + } +} + // Host matrix data random initialization template __host__ static inline void fillRand(DataT* mat, uint32_t m, uint32_t n) @@ -223,6 +232,67 @@ __host__ static inline void fillRand(DataT* mat, uint32_t m, uint32_t n) } } +#include + +template +__host__ static inline void fillEnc(DataT* mat, uint32_t m, uint32_t n) +{ + using EncT = std::conditional_t; + //#pragma omp parallel for + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + EncT enc = ((i & 0xFF) << (sizeof(DataT) * 4)) | (j & 0xFF); + //std::cout << "row: " << i << " col: " << j << " :"; + //std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT)*2) << std::right << std::hex << ((i & 0xFF) << (sizeof(DataT)*4)) << " " << (j & 0xFF) << std::endl; + //std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT)*2) << std::right << std::hex << enc << std::endl; + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + mat[idx] = reinterpret_cast(enc); + } + } +} + +template +__host__ static inline void printEnc(DataT* mat, uint32_t m, uint32_t n) +{ + using EncT = std::conditional_t; + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT) * 2) << std::right + << std::hex << reinterpret_cast(mat[idx]) << " "; + //std::cout << reinterpret_cast(mat[idx]) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +__host__ static inline void printData(DataT* mat, uint32_t m, uint32_t n) +{ + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + std::cout << std::setw(8) << std::right << float(mat[idx]) << " "; + //std::cout << reinterpret_cast(mat[idx]) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + // Host GEMM validation template ( - ldsAddr, applyDataLayout(applyTranspose(grBuffB)), ldsld, waveIndexB); + ldsAddr, + applyDataLayout(applyTranspose(grBuffB)), + ldsld, + waveIndexB); } -// Local A reads for warp tile gemm, non-cooperative +// Global read (macro tile) +using LRBuffA = fragment; +using LRBuffB = ApplyTranspose_t; +using GRBuffC = fragment; +using AccumBuffInt = fragment; + ROCWMMA_DEVICE static inline void - localReadA(MfmaFragA (&fragsA)[BLOCKS_X], InputT const* ldsAddrA, uint32_t ldsld) + localReadA(LRBuffA& fragsA, InputT const* ldsAddrA, uint32_t ldsld) { - using FragShape = GetIOShape_t; - using Mapper1d = GetDataLayout_t; + constexpr uint32_t VW = 4; - // Each A block is stacked vertically in LDS - auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + using Profile = rocwmma::LayoutProfile:: + ColInt; -#pragma unroll - for(int i = 0; i < BLOCKS_X; i++) - { - LRFragA tmp; - load_matrix_sync(tmp, ldsAddrA, ldsld); - fragsA[i] = applyDataLayout(tmp); + using DataLayout = typename Profile::DataLayout; + using MatrixLayout = typename Profile::MatrixLayout; - ldsAddrA += blockStep; - } + using Loader = OpaqueLoad; + + // Load then implicit pack + Loader::exec(fragsA.mAccess, ldsAddrA, ldsld); + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63) + // { + // auto reg = 0u; + // auto x0 = fragsA.mAccess.data[0]; + // auto x1 = fragsA.mAccess.data[1]; + // auto x2 = fragsA.mAccess.data[2]; + // auto x3 = fragsA.mAccess.data[3]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } } // Local B reads for warp tile gemm, non-cooperative ROCWMMA_DEVICE static inline void - localReadB(MfmaFragB (&fragsB)[BLOCKS_Y], InputT const* ldsAddrB, uint32_t ldsld) + localReadB(LRBuffB& fragsB, InputT const* ldsAddrB, uint32_t ldsld) { - using FragShape = GetIOShape_t; - using Mapper1d = GetDataLayout_t; + // How to choose? Comes from the IOConfig? + constexpr uint32_t VW = 4; - // Each B block is stacked vertically in LDS - auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + using Profile = rocwmma::LayoutProfile:: + ColInt; -#pragma unroll - for(int i = 0; i < BLOCKS_Y; i++) - { - LRFragB tmp; - load_matrix_sync(tmp, ldsAddrB, ldsld); + using MatrixLayout = typename Profile::MatrixLayout; + using DataLayout = typename Profile::DataLayout; - // Transform back to MFMA tile - fragsB[i] = applyDataLayout(applyTranspose(tmp)); + using Loader = OpaqueLoad; - ldsAddrB += blockStep; - } + // Load then implicit pack + Loader::exec(reinterpret_cast(fragsB).mAccess, ldsAddrB, ldsld); + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63) + // { + // auto reg = 0u; + // auto x0 = fragsB.mAccess.data[0]; + // auto x1 = fragsB.mAccess.data[1]; + // auto x2 = fragsB.mAccess.data[2]; + // auto x3 = fragsB.mAccess.data[3]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } } // Global C reads for warp tile gemm, non-cooperative -ROCWMMA_DEVICE static inline void - globalReadC(MfmaFragC (&fragC)[BLOCKS_X][BLOCKS_Y], OutputT const* gAddrC, uint32_t ldc) +ROCWMMA_DEVICE static inline void globalReadC(GRBuffC& fragsC, OutputT const* gAddrC, uint32_t ldc) { - using FragShape = GetIOShape_t; - using Mapper1d = GetDataLayout_t; - - // Iterative offsets for each C block in the wave tile - auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldc); - auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldc); - -#pragma unroll - for(int i = 0; i < BLOCKS_X; i++) + // How to choose? Comes from the IOConfig? + constexpr uint32_t VW = 4; + + using Profile = rocwmma::LayoutProfile:: + RowInt; + + using MatrixLayout = typename Profile::MatrixLayout; + using DataLayout = typename Profile::DataLayout; + + using Loader = OpaqueLoad; + + // Load then implicit pack + GRBuffC tmp; + Loader::exec(tmp.mAccess, gAddrC, ldc); + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // auto reg = 0u; + // auto x0 = tmp.mAccess.data[0]; + // auto x1 = tmp.mAccess.data[1]; + // auto x2 = tmp.mAccess.data[2]; + // auto x3 = tmp.mAccess.data[3]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + //MatrixLayout::debug(); { - auto offsetY = 0u; + // Post load to accum format + #pragma unroll - for(int j = 0; j < BLOCKS_Y; j++) + for(int i = 0; i < 4u; i++) { - load_matrix_sync(fragC[i][j], gAddrC + offsetY, ldc); - offsetY += blockStepY; + fragsC.mAccess.data[0 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 0]; + fragsC.mAccess.data[1 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 1]; + fragsC.mAccess.data[2 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 2]; + fragsC.mAccess.data[3 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 3]; + + fragsC.mAccess.data[0 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 0]; + fragsC.mAccess.data[1 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 1]; + fragsC.mAccess.data[2 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 2]; + fragsC.mAccess.data[3 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 3]; + + fragsC.mAccess.data[0 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 0]; + fragsC.mAccess.data[1 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 1]; + fragsC.mAccess.data[2 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 2]; + fragsC.mAccess.data[3 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 3]; + + fragsC.mAccess.data[0 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 0]; + fragsC.mAccess.data[1 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 1]; + fragsC.mAccess.data[2 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 2]; + fragsC.mAccess.data[3 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 3]; } - gAddrC += blockStepX; } + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // auto reg = 0u; + // auto x0 = fragsC.mAccess.data[12]; + // auto x1 = fragsC.mAccess.data[13]; + // auto x2 = fragsC.mAccess.data[14]; + // auto x3 = fragsC.mAccess.data[15]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } } // Global D reads for warp tile gemm, non-cooperative -ROCWMMA_DEVICE static inline void - globalWriteD(OutputT* gAddrD, MfmaFragD const (&fragsD)[BLOCKS_X][BLOCKS_Y], uint32_t ldd) +ROCWMMA_DEVICE static inline void globalWriteD(OutputT* gAddrD, GRBuffC const& fragsD, uint32_t ldd) { - using FragShape = GetIOShape_t; - using Mapper1d = GetDataLayout_t; - - // Iterative offsets for each D block in the warp tile - auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldd); - auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldd); - + // How to choose? Comes from the IOConfig? + constexpr uint32_t VW = 4; + + using Profile = rocwmma::LayoutProfile:: + RowInt; + + using MatrixLayout = typename Profile::MatrixLayout; + using DataLayout = typename Profile::DataLayout; + + using Storer = OpaqueStore; + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // auto reg = 0u; + // auto x0 = fragsD.mAccess.data[0]; + // auto x1 = fragsD.mAccess.data[16]; + // auto x2 = fragsD.mAccess.data[32]; + // auto x3 = fragsD.mAccess.data[48]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + + // Pre-store to output fmt + GRBuffC tmp; + // tmp.mAccess.data[0] = fragsD.mAccess.data[0]; + // tmp.mAccess.data[1] = fragsD.mAccess.data[16]; + // tmp.mAccess.data[2] = fragsD.mAccess.data[32]; + // tmp.mAccess.data[3] = fragsD.mAccess.data[48]; + // tmp.mAccess.data[4] = fragsD.mAccess.data[4]; + // tmp.mAccess.data[5] = fragsD.mAccess.data[20]; + // tmp.mAccess.data[6] = fragsD.mAccess.data[36]; + // tmp.mAccess.data[7] = fragsD.mAccess.data[52]; + // tmp.mAccess.data[8] = fragsD.mAccess.data[8]; + // tmp.mAccess.data[9] = fragsD.mAccess.data[24]; + // tmp.mAccess.data[10] = fragsD.mAccess.data[40]; + // tmp.mAccess.data[11] = fragsD.mAccess.data[56]; + // tmp.mAccess.data[12] = fragsD.mAccess.data[12]; + // tmp.mAccess.data[13] = fragsD.mAccess.data[28]; + // tmp.mAccess.data[14] = fragsD.mAccess.data[44]; + // tmp.mAccess.data[15] = fragsD.mAccess.data[60]; + // tmp.mAccess.data[16] = fragsD.mAccess.data[1]; + // tmp.mAccess.data[17] = fragsD.mAccess.data[17]; + // tmp.mAccess.data[18] = fragsD.mAccess.data[33]; + // tmp.mAccess.data[19] = fragsD.mAccess.data[49]; + // tmp.mAccess.data[20] = fragsD.mAccess.data[5]; + // tmp.mAccess.data[21] = fragsD.mAccess.data[21]; + // tmp.mAccess.data[22] = fragsD.mAccess.data[37]; + // tmp.mAccess.data[23] = fragsD.mAccess.data[53]; + // tmp.mAccess.data[24] = fragsD.mAccess.data[9]; + // tmp.mAccess.data[25] = fragsD.mAccess.data[25]; + // tmp.mAccess.data[26] = fragsD.mAccess.data[41]; + // tmp.mAccess.data[27] = fragsD.mAccess.data[57]; + // tmp.mAccess.data[28] = fragsD.mAccess.data[13]; + // tmp.mAccess.data[29] = fragsD.mAccess.data[29]; + // tmp.mAccess.data[30] = fragsD.mAccess.data[45]; + // tmp.mAccess.data[31] = fragsD.mAccess.data[61]; + // tmp.mAccess.data[32] = fragsD.mAccess.data[2]; + // tmp.mAccess.data[33] = fragsD.mAccess.data[18]; + // tmp.mAccess.data[34] = fragsD.mAccess.data[34]; + // tmp.mAccess.data[35] = fragsD.mAccess.data[50]; + // tmp.mAccess.data[36] = fragsD.mAccess.data[6]; + // tmp.mAccess.data[37] = fragsD.mAccess.data[22]; + // tmp.mAccess.data[38] = fragsD.mAccess.data[38]; + // tmp.mAccess.data[39] = fragsD.mAccess.data[54]; + // tmp.mAccess.data[40] = fragsD.mAccess.data[10]; + // tmp.mAccess.data[41] = fragsD.mAccess.data[26]; + // tmp.mAccess.data[42] = fragsD.mAccess.data[42]; + // tmp.mAccess.data[43] = fragsD.mAccess.data[58]; + // tmp.mAccess.data[44] = fragsD.mAccess.data[14]; + // tmp.mAccess.data[45] = fragsD.mAccess.data[30]; + // tmp.mAccess.data[46] = fragsD.mAccess.data[46]; + // tmp.mAccess.data[47] = fragsD.mAccess.data[62]; + // tmp.mAccess.data[48] = fragsD.mAccess.data[3]; + // tmp.mAccess.data[49] = fragsD.mAccess.data[19]; + // tmp.mAccess.data[50] = fragsD.mAccess.data[35]; + // tmp.mAccess.data[51] = fragsD.mAccess.data[51]; + // tmp.mAccess.data[52] = fragsD.mAccess.data[7]; + // tmp.mAccess.data[53] = fragsD.mAccess.data[23]; + // tmp.mAccess.data[54] = fragsD.mAccess.data[39]; + // tmp.mAccess.data[55] = fragsD.mAccess.data[55]; + // tmp.mAccess.data[56] = fragsD.mAccess.data[11]; + // tmp.mAccess.data[57] = fragsD.mAccess.data[27]; + // tmp.mAccess.data[58] = fragsD.mAccess.data[43]; + // tmp.mAccess.data[59] = fragsD.mAccess.data[59]; + // tmp.mAccess.data[60] = fragsD.mAccess.data[15]; + // tmp.mAccess.data[61] = fragsD.mAccess.data[31]; + // tmp.mAccess.data[62] = fragsD.mAccess.data[47]; + // tmp.mAccess.data[63] = fragsD.mAccess.data[63]; #pragma unroll - for(int i = 0; i < BLOCKS_X; i++) + for(int i = 0; i < 4u; i++) { - auto offsetY = 0u; -#pragma unroll - for(int j = 0; j < BLOCKS_Y; j++) - { - store_matrix_sync(gAddrD + offsetY, fragsD[i][j], ldd); - offsetY += blockStepY; - } - gAddrD += blockStepX; + tmp.mAccess.data[i * 16 + 0 + 0] = fragsD.mAccess.data[0 * 16 + 0 + i]; + tmp.mAccess.data[i * 16 + 0 + 1] = fragsD.mAccess.data[1 * 16 + 0 + i]; + tmp.mAccess.data[i * 16 + 0 + 2] = fragsD.mAccess.data[2 * 16 + 0 + i]; + tmp.mAccess.data[i * 16 + 0 + 3] = fragsD.mAccess.data[3 * 16 + 0 + i]; + + tmp.mAccess.data[i * 16 + 4 + 0] = fragsD.mAccess.data[0 * 16 + 4 + i]; + tmp.mAccess.data[i * 16 + 4 + 1] = fragsD.mAccess.data[1 * 16 + 4 + i]; + tmp.mAccess.data[i * 16 + 4 + 2] = fragsD.mAccess.data[2 * 16 + 4 + i]; + tmp.mAccess.data[i * 16 + 4 + 3] = fragsD.mAccess.data[3 * 16 + 4 + i]; + + tmp.mAccess.data[i * 16 + 8 + 0] = fragsD.mAccess.data[0 * 16 + 8 + i]; + tmp.mAccess.data[i * 16 + 8 + 1] = fragsD.mAccess.data[1 * 16 + 8 + i]; + tmp.mAccess.data[i * 16 + 8 + 2] = fragsD.mAccess.data[2 * 16 + 8 + i]; + tmp.mAccess.data[i * 16 + 8 + 3] = fragsD.mAccess.data[3 * 16 + 8 + i]; + + tmp.mAccess.data[i * 16 + 12 + 0] = fragsD.mAccess.data[0 * 16 + 12 + i]; + tmp.mAccess.data[i * 16 + 12 + 1] = fragsD.mAccess.data[1 * 16 + 12 + i]; + tmp.mAccess.data[i * 16 + 12 + 2] = fragsD.mAccess.data[2 * 16 + 12 + i]; + tmp.mAccess.data[i * 16 + 12 + 3] = fragsD.mAccess.data[3 * 16 + 12 + i]; } + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // if(threadIdx.x == 0) + // { + // printf("D Before STORE\n"); + // printf("Count: %d\n", tmp.num_elements); + // } + // auto reg = 0u; + // auto x0 = tmp.mAccess.data[0]; + // auto x1 = tmp.mAccess.data[16]; + // auto x2 = tmp.mAccess.data[32]; + // auto x3 = tmp.mAccess.data[48]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + + // Load then implicit pack + Storer::exec(gAddrD, tmp.mAccess, ldd); } -// Broadcast value to fragments in warp tile -template -ROCWMMA_DEVICE static inline void fill(FragT (&frags)[BLOCKS_X][BLOCKS_Y], - GetDataType_t value) +// Performs warp tile mfma +ROCWMMA_DEVICE static inline void mfma(AccumBuffInt& fragsAccOut, + LRBuffA const& fragsA, + LRBuffB const& fragsB, + AccumBuffInt const& fragsAccIn) { + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // auto x0 = fragsA.mAccess.data[0]; + // auto x1 = fragsA.mAccess.data[1]; + // auto x2 = fragsA.mAccess.data[2]; + // auto x3 = fragsA.mAccess.data[3]; + // printf("(A)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + + // x0 = fragsB.mAccess.data[0]; + // x1 = fragsB.mAccess.data[1]; + // x2 = fragsB.mAccess.data[2]; + // x3 = fragsB.mAccess.data[3]; + // printf("(B)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + // Need to get the MFMA tile size from the IO traits somehow + constexpr static uint32_t MFMFA_TILE = 16u; + + // From here, need to 'unpack' the interleaved data + // Should be 16 registers, need to re-order them in groups of 4 + LRBuffA tmpA; + LRBuffB tmpB; #pragma unroll - for(int i = 0; i < BLOCKS_X; i++) + for(int i = 0; i < 4u; i++) { -#pragma unroll - for(int j = 0; j < BLOCKS_Y; j++) - { - fill_fragment(frags[i][j], value); - } + tmpA.mAccess.data[i * 4 + 0] = fragsA.mAccess.data[0 * 4 + i]; + tmpA.mAccess.data[i * 4 + 1] = fragsA.mAccess.data[1 * 4 + i]; + tmpA.mAccess.data[i * 4 + 2] = fragsA.mAccess.data[2 * 4 + i]; + tmpA.mAccess.data[i * 4 + 3] = fragsA.mAccess.data[3 * 4 + i]; + + tmpB.mAccess.data[i * 4 + 0] = fragsB.mAccess.data[0 * 4 + i]; + tmpB.mAccess.data[i * 4 + 1] = fragsB.mAccess.data[1 * 4 + i]; + tmpB.mAccess.data[i * 4 + 2] = fragsB.mAccess.data[2 * 4 + i]; + tmpB.mAccess.data[i * 4 + 3] = fragsB.mAccess.data[3 * 4 + i]; } -} -// Performs warp tile mfma -ROCWMMA_DEVICE static inline void mfma(MfmaFragAcc (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], - MfmaFragA const (&fragsA)[BLOCKS_X], - MfmaFragB const (&fragsB)[BLOCKS_Y], - MfmaFragAcc const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) -{ + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // auto x0 = tmpA.mAccess.data[12]; + // auto x1 = tmpA.mAccess.data[13]; + // auto x2 = tmpA.mAccess.data[14]; + // auto x3 = tmpA.mAccess.data[15]; + // printf("(A)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + + // x0 = tmpB.mAccess.data[12]; + // x1 = tmpB.mAccess.data[13]; + // x2 = tmpB.mAccess.data[14]; + // x3 = tmpB.mAccess.data[15]; + // printf("(B)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + + // Iterate over MFMA input requirements + // A = 16 regs unpacked, 8 packed + // B = 16 regs unpacked, 8 packed + // Accum = 64 regs unpacked/packed + // MFMA blocks = 16 x 4 regs + // Iterate through A - major + auto bIt = makeVectorIterator<2u>(tmpB.mStorage).begin(); + auto const accumInIt = makeVectorIterator<4u>(fragsAccOut.mStorage).begin(); + auto accumOutIt = makeVectorIterator<4u>(fragsAccOut.mStorage).begin(); + + using MMA = Mfma; + #pragma unroll - for(int i = 0; i < BLOCKS_X; i++) + for(int j = 0; j < 4u; j++) { + auto aIt = makeVectorIterator<2u>(tmpA.mStorage).begin(); #pragma unroll - for(int j = 0; j < BLOCKS_Y; j++) + for(int i = 0; i < 4u; i++) { - mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); + // mma functions operate on packed vectors + *accumOutIt = MMA::exec(*aIt, *bIt, *accumInIt); + aIt++; + accumInIt++; + accumOutIt++; } + bIt++; } + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // if(threadIdx.x == 0) + // { + // printf("Count: %d\n", fragsAccOut.num_elements); + // } + // auto reg = 0u; + // auto x0 = fragsAccOut.mAccess.data[0]; + // auto x1 = fragsAccOut.mAccess.data[1]; + // auto x2 = fragsAccOut.mAccess.data[2]; + // auto x3 = fragsAccOut.mAccess.data[3]; + // printf("Thread %d: %#010x %#010x %#010x %#010x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } +} + +// Broadcast value to fragments in warp tile +template +ROCWMMA_DEVICE static inline void fill(FragT& frags, GetDataType_t value) +{ + fill_fragment(frags, value); } // Uniform multiply - add (FMA) // Performs D = alpha * acc + beta * C, where alpha, beta are uniform scalars -ROCWMMA_DEVICE static inline void uniformFma(MfmaFragD (&fragsD)[BLOCKS_X][BLOCKS_Y], - ComputeT alpha, - MfmaFragAcc const (&fragsAcc)[BLOCKS_X][BLOCKS_Y], - ComputeT beta, - MfmaFragC const (&fragsC)[BLOCKS_X][BLOCKS_Y]) +ROCWMMA_DEVICE static inline void uniformFma(GRBuffC& fragsD, + ComputeT alpha, + AccumBuffInt const& fragsAcc, + ComputeT beta, + GRBuffC const& fragsC) { -#pragma unroll - for(int i = 0; i < BLOCKS_X; i++) + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // if(threadIdx.x == 0) + // { + // printf("Count: %d\n", fragsAcc.num_elements); + // } + // auto reg = 0u; + // auto x0 = fragsAcc.mAccess.data[0]; + // auto x1 = fragsAcc.mAccess.data[1]; + // auto x2 = fragsAcc.mAccess.data[2]; + // auto x3 = fragsAcc.mAccess.data[3]; + // printf("Thread %d: %#010x %#010x %#010x %#010x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // if(threadIdx.x == 0) + // { + // printf("Count: %d\n", fragsC.num_elements); + // } + // auto reg = 0u; + // auto x0 = fragsC.mAccess.data[0]; + // auto x1 = fragsC.mAccess.data[1]; + // auto x2 = fragsC.mAccess.data[2]; + // auto x3 = fragsC.mAccess.data[3]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } + + static constexpr uint32_t ChunkFactor = 2u; + static constexpr uint32_t ChunkSize = 64u / ChunkFactor; + auto dIt = makeVectorIterator(fragsD.mAccess).begin(); + auto const accumIt = makeVectorIterator(fragsAcc.mAccess).begin(); + auto const cIt = makeVectorIterator(fragsC.mAccess).begin(); + + for(int k = 0; k < fragsD.num_elements / ChunkFactor; k++) { -#pragma unroll - for(int j = 0; j < BLOCKS_Y; j++) - { - for(int k = 0; k < fragsD[i][j].num_elements; k++) - { - // Perform computation in ComputeT and cast back to OutputT - fragsD[i][j].x[k] = static_cast( - alpha * fragsAcc[i][j].x[k] + beta * static_cast(fragsC[i][j].x[k])); - } - } + // Perform computation in ComputeT and cast back to OutputT + (*dIt).data[k] = static_cast(alpha * (*accumIt).data[k] + + beta * static_cast((*cIt).data[k])); + } + + dIt++; + accumIt++; + cIt++; + + for(int k = 0; k < fragsD.num_elements / ChunkFactor; k++) + { + // Perform computation in ComputeT and cast back to OutputT + (*dIt).data[k] = static_cast(alpha * (*accumIt).data[k] + + beta * static_cast((*cIt).data[k])); } + + // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) + // { + // if(threadIdx.x == 0) + // { + // printf("D AFTER UNIFORM FMA\n"); + // printf("Count: %d\n", fragsD.num_elements); + // } + // auto reg = 0u; + // auto x0 = fragsD.mAccess.data[0]; + // auto x1 = fragsD.mAccess.data[16]; + // auto x2 = fragsD.mAccess.data[32]; + // auto x3 = fragsD.mAccess.data[48]; + // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); + // } } -ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, - uint32_t n, - uint32_t k, - InputT const* a, - InputT const* b, - OutputT const* c, - OutputT* d, - uint32_t lda, - uint32_t ldb, - uint32_t ldc, - uint32_t ldd, - ComputeT alpha, - ComputeT beta) +//ROCWMMA_KERNEL void gemm_rocwmma_d(uint32_t m, +//ROCWMMA_KERNEL void __attribute__((amdgpu_num_vgpr(0))) gemm_rocwmma_d(uint32_t m, +ROCWMMA_KERNEL void __launch_bounds__(1024) gemm_rocwmma_d( + uint32_t m, + //ROCWMMA_KERNEL void __attribute__((amdgpu_waves_per_eu(1))) gemm_rocwmma_d(uint32_t m, + uint32_t n, + uint32_t k, + InputT const* a, + InputT const* b, + OutputT const* c, + OutputT* d, + uint32_t lda, + uint32_t ldb, + uint32_t ldc, + uint32_t ldd, + ComputeT alpha, + ComputeT beta) { if constexpr(!ROCWMMA_ARCH_HOST) { @@ -637,7 +938,7 @@ ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, /// /// Initialize accumulation frags /// - MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; + AccumBuffInt fragsAcc; fill(fragsAcc, 0.0f); /// @@ -648,19 +949,27 @@ ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, /// /// Accumulate A * B for all mfma frags in warp tile /// + // - LDS Triple buffer + // - LDS no buffer-> tiny m/n large K + // - unroll K to have more work + // - __restrict__ + // for(uint32_t currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) { - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; - - // Local read mfma frags from first LDS buffer - localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); - localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + // Make sure that all waves have finished reading / writing to lds for currentK. + synchronize_workgroup(); // Prefetch next round of global frags globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); + LRBuffA fragsA; + LRBuffB fragsB; + + // Local read mfma frags from first LDS buffer + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + // Advance offsets to next k step globalReadOffsetA += kStepOffsetA; globalReadOffsetB += kStepOffsetB; @@ -672,41 +981,93 @@ ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); - // Make sure that all waves have finished reading / writing to lds for currentK. - synchronize_workgroup(); - // Swap Lds buffers auto* tmp = ldsPtrLo; ldsPtrLo = ldsPtrHi; ldsPtrHi = tmp; + + // Scheduling + + // // VMEM read + // __builtin_amdgcn_sched_group_barrier(32, 2, 0); + // // DS read + // __builtin_amdgcn_sched_group_barrier(256, 16, 0); + // // Non-VMEM + // __builtin_amdgcn_sched_group_barrier(1, 16, 0); + // // MFMA + // __builtin_amdgcn_sched_group_barrier(8, 4, 0); + // // DS read + // __builtin_amdgcn_sched_group_barrier(256, 16, 1); + // // // Non-VMEM + // __builtin_amdgcn_sched_group_barrier(1, 16, 1); + // // MFMA + // __builtin_amdgcn_sched_group_barrier(8, 4, 1); + // // DS write + // __builtin_amdgcn_sched_group_barrier(512, 32, 0); + + ////////// Works good - 127.46 + // VMEM read + __builtin_amdgcn_sched_group_barrier(32, 4, 0); + // DS read + __builtin_amdgcn_sched_group_barrier(256, 64, 0); + // SALU + __builtin_amdgcn_sched_group_barrier(4, 256, 0); + // VALU + __builtin_amdgcn_sched_group_barrier(2, 256, 0); + // MFMA + __builtin_amdgcn_sched_group_barrier(8, 16, 0); + // DS write + __builtin_amdgcn_sched_group_barrier(512, 64, 0); + ////////////////// } + // Make sure that all waves have finished reading / writing to lds for currentK. + synchronize_workgroup(); + /// /// Start loading C /// using MfmaFragCMap1d = GetDataLayout_t; using MfmaFragDMap1d = GetDataLayout_t; - MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; + GRBuffC fragsC; globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - /// - /// Clean up tail A * B - /// - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; + // /// + // /// Clean up tail A * B + // /// + LRBuffA fragsA; + LRBuffB fragsB; - // Local read mfma frags + // // Local read mfma frags localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); mfma(fragsAcc, fragsA, fragsB, fragsAcc); - /// - /// D = alpha * accum + beta * C - /// - MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; + // /// + // /// D = alpha * accum + beta * C + // /// + GRBuffC fragsD; uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); + //globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), reinterpret_cast(fragsAcc), ldd); globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); + + ////////// Works good - 127.46 + // DS read + __builtin_amdgcn_sched_group_barrier(256, 64, 0); + // VMEM read + __builtin_amdgcn_sched_group_barrier(32, 64, 0); + + // MFMA + __builtin_amdgcn_sched_group_barrier(8, 16, 0); + // SALU + __builtin_amdgcn_sched_group_barrier(4, 256, 0); + // VALU + __builtin_amdgcn_sched_group_barrier(2, 512, 0); + + // VMEM write + __builtin_amdgcn_sched_group_barrier(512, 64, 0); + ////////////////// } } @@ -780,6 +1141,12 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, fillRand(matrixA.data(), m, k); fillRand(matrixB.data(), k, n); fillRand(matrixC.data(), m, n); + //fillEnc(matrixA.data(), m, k); + //printEnc(matrixA.data(), m, k); + //fillEnc(matrixB.data(), k, n); + //printEnc(matrixB.data(), k, n); + //fillEnc(matrixC.data(), m, n); + //printEnc(matrixC.data(), m, n); std::cout << "Initializing device data..." << std::endl; @@ -841,7 +1208,7 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, beta); }; - constexpr uint32_t warmups = 2u; + constexpr uint32_t warmups = 50u; constexpr uint32_t recordRuns = 5u; // Warm-up runs, not recorded @@ -888,7 +1255,7 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, << ldc << ", " << ldd << ", " << elapsedTimeMs << ", " << gFlops << ", " << tFlopsPerSec << std::endl; -#if !NDEBUG +#if 1 std::cout << "Validating result with reference..." << std::endl; @@ -918,6 +1285,11 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, auto res = compareEqual(matrixD.data(), matrixD_ref.data(), m * n); + //std::cout << "Reference: \n"; + //printData(matrixD_ref.data(), m, n); + //std::cout << "Actual:\n"; + //printData(matrixD.data(), m, n); + if(std::get<0>(res) == false) { std::cout << "FAILED\n"; @@ -943,5 +1315,7 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, int main() { gemm_test(7168, 7168, 7168, 2, 2); + //gemm_test(8192, 8192, 8192, 2, 2); + //gemm_test(128, 128, 16, 2, 2); return 0; } From ce0f116e2b849778b776548c485c18d0a93797a5 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 9 Sep 2024 14:39:17 +0000 Subject: [PATCH 03/36] Initial layout classes refactor --- .../internal/layout/data_layout_impl.hpp | 79 + .../rocwmma/internal/layout/layout.hpp | 184 +++ .../{layout.hpp => layout/layout_profile.hpp} | 215 +-- .../rocwmma/internal/layout/layout_traits.hpp | 77 + .../internal/layout/layout_traits_impl.hpp | 69 + .../matrix_layout_impl.hpp} | 1071 +++++-------- .../layout/matrix_layout_interleaved_impl.hpp | 1348 +++++++++++++++++ .../internal/layout/register_layout_impl.hpp | 44 + .../include/rocwmma/internal/transforms.hpp | 183 ++- .../rocwmma/internal/transforms_impl.hpp | 1 - .../include/rocwmma/internal/vector_util.hpp | 11 + .../rocwmma/internal/vector_util_impl.hpp | 14 + library/include/rocwmma/rocwmma_impl.hpp | 8 +- .../rocwmma/rocwmma_transforms_impl.hpp | 87 +- 14 files changed, 2487 insertions(+), 904 deletions(-) create mode 100644 library/include/rocwmma/internal/layout/data_layout_impl.hpp create mode 100644 library/include/rocwmma/internal/layout/layout.hpp rename library/include/rocwmma/internal/{layout.hpp => layout/layout_profile.hpp} (68%) create mode 100644 library/include/rocwmma/internal/layout/layout_traits.hpp create mode 100644 library/include/rocwmma/internal/layout/layout_traits_impl.hpp rename library/include/rocwmma/internal/{layout_impl.hpp => layout/matrix_layout_impl.hpp} (52%) create mode 100644 library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp create mode 100644 library/include/rocwmma/internal/layout/register_layout_impl.hpp diff --git a/library/include/rocwmma/internal/layout/data_layout_impl.hpp b/library/include/rocwmma/internal/layout/data_layout_impl.hpp new file mode 100644 index 00000000..eccbb67a --- /dev/null +++ b/library/include/rocwmma/internal/layout/data_layout_impl.hpp @@ -0,0 +1,79 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_DATA_LAYOUT_IMPL_HPP +#define ROCWMMA_DATA_LAYOUT_IMPL_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" +#include "utility/type_traits.hpp" + +namespace rocwmma +{ + // Data layout trait tags are transposes + template <> + struct is_layout_transpose : public true_type + { + }; + + template <> + struct is_layout_transpose : public true_type + { + }; + + // Data layout objects are transposes + template <> + struct is_layout_transpose : public true_type + { + }; + + template <> + struct is_layout_transpose : public true_type + { + }; + + // Data layout trait tag transpose + template <> + struct layout_transpose + { + using type = col_major; + }; + + template <> + struct layout_transpose + { + using type = row_major; + }; + + // Data layout object type transpose + template + struct layout_transpose> + { + using Type = DataLayout::template Array1d>; + }; + +} // namespace rocwmma + +#endif // ROCWMMA_DATA_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp new file mode 100644 index 00000000..0c30d615 --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -0,0 +1,184 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_HPP +#define ROCWMMA_LAYOUT_HPP + +#include "mapping_util.hpp" + +namespace rocwmma +{ + // DataLayout objects map 2D matrix coordinate space to 1D data arrays offsets. + // DataLayoutT tags describe whether consecutive elements are: + // 1. Contiguous rows (row_major) + // 2. Contiguous columns (col_major) + namespace DataLayout + { + /*! \class Array1d + * \brief A class to help map from 2D matrix space to 1D data space. + * @tparam DataLayoutT Meta-tag indicating whether data is stored in + * row_major or col_major order. + */ + template + using Array1d = typename ::rocwmma::detail::template DataSpace; + + /*! \class RowMajor + * \brief Maps 2D matrix space to row_major 1D data space + */ + using RowMajor = Array1d; + + /*! \class ColMajor + * \brief Maps 2D matrix space to col_major 1D data space + */ + using ColMajor = Array1d; + + } // namespace DataLayout + + // Matrix Layouts map thread offsets into 2D matrix coordinate space: + // 1. Base thread offsets + // 2. Stride offsets + // 3. Stride counts + // 4. Per-iteration offsets (stride step based on iteration) + // 5. Cumulative offsets (cumulative stride steps based on iteration) + namespace MatrixLayout + { + /*! \class ColOrthoVW + * \brief A matrix layout that maps contiguous threads to contiguous column elements, in the BlockDim direction. + * VectorWidth elements are mapped orthogonal to the column, in the BlockK Direction. + * @tparam BlockDim The height of the column + * @tparam BlockK The number of columns + * @tparam DataT The datatype + * @tparam VectorWidth The iterative vector width + * @tparam MaxVectorWidth The total vector width + */ + template + struct ColOrthoVW; + + template + struct ColInlineVW; + + template + struct RowOrthoVW; + + template + struct RowInlineVW; + + /////////////////// Interleaved patterns ////////////////// + template // # of splits + struct ColInlineInt; + + template // # of splits + struct ColOrthoInt; + + template // # of splits + struct RowInlineInt; + + template // # of splits + struct RowOrthoInt; + + /////////////////// ////////////////////////////// ////////////////// + + } // namespace MatrixLayout + + // Register layouts describe in-register layout and serve as transform states, or endpoints. + // These are mnemonics which provide: + // 1. A relationship between in-register layouts and combinations of matrix / data layouts. + // 2. Useful parameters that may be used in transformations between endpoints. + // 3. With indications from layout traits, can determine likeness or orthogonality between states. + // Note: For these mnemonics to be useful, there must exist a transformable path between layouts. + // Example: + // Suppose we associate associate fragment register data with Storage upon loading. + // To use the fragment register data with mma functions, we may attempt to transform the data from + // Storage to MmaInput<16> to serve as input to a 16x16xk mma builtin. + namespace RegisterLayout + { + // A mnemonic used to describe the register layout is suitable for input/output + template + struct Storage + { + }; + + // A mnemonic used to describe the register layout is suitable for mma input for A/B + template + struct MmaInput + { + }; + + // A mnemonic used to describe the register layout is suitable for mma input for accumulator input/output + template + struct MmaAcc + { + }; + + } // namespace RegisterLayout + +} // namespace rocwmma + +#include "data_layout_impl.hpp" +#include "matrix_layout_impl.hpp" +#include "matrix_layout_interleaved_impl.hpp" +#include "register_layout_impl.hpp" + +#endif // ROCWMMA_LAYOUT_HPP diff --git a/library/include/rocwmma/internal/layout.hpp b/library/include/rocwmma/internal/layout/layout_profile.hpp similarity index 68% rename from library/include/rocwmma/internal/layout.hpp rename to library/include/rocwmma/internal/layout/layout_profile.hpp index b7dfa5d2..4e243061 100644 --- a/library/include/rocwmma/internal/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout_profile.hpp @@ -23,132 +23,27 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef ROCWMMA_LAYOUT_HPP -#define ROCWMMA_LAYOUT_HPP +#ifndef ROCWMMA_LAYOUT_PROFILE_HPP +#define ROCWMMA_LAYOUT_PROFILE_HPP -#include "mapping_util.hpp" +#include "layout.hpp" namespace rocwmma { - // In relation to matrix space, DataLayouts describe whether consecutive elements in 1D data arrays are: - // 1. Contiguous rows (row_major) - // 2. Contiguous columns (col_major) - namespace DataLayout - { - template - using Array1d = typename ::rocwmma::detail::template DataSpace; - - using RowMajor = Array1d; - using ColMajor = Array1d; - - } // namespace DataLayout - - // In 2D space, Matrix Layouts describe per-thread offset coordinates and iterative spaces - // 1. Base thread offsets - // 2. Stride offsets - // 3. Stride spaces (counts) - // 4. Per-iteration offsets (stride step based on iteration) - // 5. Cumulative offsets (cumulative stride steps based on iteration) - namespace MatrixLayout - { - template - struct ColOrthoVW; - - template - struct ColInlineVW; - - template - struct RowOrthoVW; - - template - struct RowInlineVW; - - /////////////////// Interleaved patterns ////////////////// - template // # of splits - struct ColInlineInt; - - template // # of splits - struct ColOrthoInt; - - template // # of splits - struct RowInlineInt; - - template // # of splits - struct RowOrthoInt; - - /////////////////// ////////////////////////////// ////////////////// - - } // namespace MatrixLayout - - // Register layouts describe whether contiguous BlockDim elements are: - // 1. Captured in the same register lane as if the input were in Array-Of-Structures (AOS) - // 2. Captured across multiple register lanes as if the input were in Structure-Of-Arrays (SOA) - namespace RegisterLayout - { - template - struct Aos - { - }; - template - struct Soa - { - }; - } - - // Layout profiles describe fragments in three mapped spaces: - // 1. DataLayout: data locality in memory space (row_major or col_major) - // 2. MatrixLayout: data locality in matrix space (ColOrthoVW, ColInlineVW, etc.) - // 3. RegisterLayout: data locality in register space (AOS or SOA) + // Layout profiles are high-level objects that describe fragments in three mapped spaces: + // 1. DataLayout: data locality in 1D memory space (row_major or col_major) + // 2. MatrixLayout: data locality in 2D matrix space (ColOrthoVW, ColInlineVW, etc.) + // 3. RegisterLayout: data locality in register space (Storage, or MmaInput) namespace LayoutProfile { // ColNT is a layout profile that has the following properties: // 1. Leading dimension is aligned with column elements of fragment data: // - BlockDim is assumed the column size, or BlockM dimension. // - Analogous to capturing columns of 'matrix A'. - // 2. Register elements are in MFMA friendly, or SOA register layout. - // 3. Register layout does NOT change whether DataLayout is col_major or row_major (fast DataLayoutT change). - // 4. MatrixLayout will capture contiguous column elements across multiple register lanes. - // 5. VectorWidth is fixed to 1 in col_major to ensure #3 (non-optimal). + // 2. When BlockDim is supported by mma, register elements are always in MmaInput friendly register layout. + // 3. Register layout does NOT change whether DataLayout is col_major or row_major (free DataLayoutT change). + // 4. MatrixLayout will capture contiguous column elements across contiguous threads. + // 5. VectorWidth is fixed to 1 in col_major to ensure #4 (non-optimal). template ; + using DataLayout = DataLayout::template Array1d; + using MatrixLayout = conditional_t< is_same_v, MatrixLayout::ColOrthoVW, - MatrixLayout::ColOrthoVW>; - using RegisterLayout = RegisterLayout::template Soa; + MatrixLayout::ColOrthoVW>; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -187,10 +84,10 @@ namespace rocwmma // 1. Leading dimension is aligned with row elements of fragment data: // - BlockDim is assumed the row size, or BlockN dimension. // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. Register elements are in MFMA friendly, or SOA register layout. + // 2. When BlockDim is supported by mma, register elements are always MmaInput friendly register layout. // 3. Register layout does NOT change whether DataLayout is col_major or row_major (fast DataLayoutT change). - // 4. MatrixLayout will capture contiguous row elements across multiple register lanes. - // 5. VectorWidth is fixed to 1 in row_major to ensure #3 (non-optimal). + // 4. MatrixLayout will capture contiguous row elements across contiguous threads. + // 5. VectorWidth is fixed to 1 in row_major to ensure #4 (non-optimal). template ; + using DataLayout = DataLayout::template Array1d; + using MatrixLayout = conditional_t< is_same_v, MatrixLayout::RowOrthoVW, MatrixLayout::RowOrthoVW>; - using RegisterLayout = RegisterLayout::template Soa; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -229,12 +128,11 @@ namespace rocwmma // - BlockDim is assumed the column size, or BlockM dimension. // - Analogous to capturing columns of 'matrix A'. // 2. Register layout is dynamic: - // - col_major data is stored in AOS register layout (non-MFMA friendly), and - // - row_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). + // - col_major data is stored in AOS register layout (non-MmaInput friendly), and + // - row_major data is stored in SOA register layout (MmaInput friendly). // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. + // 5. Must convert to SOA if using with MFMA. template ; + using DataLayout = DataLayout::template Array1d; + using MatrixLayout = conditional_t< is_same_v, MatrixLayout::ColInlineVW, MatrixLayout::ColOrthoVW>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -290,10 +187,8 @@ namespace rocwmma is_same_v, MatrixLayout::RowInlineVW, MatrixLayout::RowOrthoVW>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -346,10 +241,8 @@ namespace rocwmma MaxVectorWidth, MfmaDim, SplitK>>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -400,10 +293,8 @@ namespace rocwmma MaxVectorWidth, MfmaDim, SplitK>>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; + + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -416,38 +307,8 @@ namespace rocwmma "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); }; - } // namespace FragmentLayout - - /// - /// Helper to ensure layout types are consistent (same, or equivalent) - /// - template - struct ConsistencyCheck : public false_type - { - }; - - /// - /// Check for layout orthogonality - /// - template - struct OrthogonalCheck : public false_type - { - }; - - template - struct OrthogonalLayout; - - template - using orthogonal_layout_t = typename OrthogonalLayout::Type; - - template - struct is_orthogonal; - - template - inline constexpr bool is_orthogonal_v = is_orthogonal::value; + } // namespace LayoutProfile } // namespace rocwmma -#include "layout_impl.hpp" - -#endif // ROCWMMA_LAYOUT_HPP +#endif // ROCWMMA_LAYOUT_PROFILE_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp new file mode 100644 index 00000000..39c1c161 --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_TRAITS_HPP +#define ROCWMMA_LAYOUT_TRAITS_HPP + +#include "utility/type_traits.hpp" + +namespace rocwmma +{ + /*! \class is_layout_same + * \brief Compares layout types are the same, or are equivalent. Similar to is_same, + * however layouts can have an equivalency with small variations input parameters such that they + * are still technically the same. This should be used when comparing any layout types: + * DataLayout, MatrixLayout and RegisterLayout + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + struct is_layout_same : public false_type + { + }; + + /*! \class is_layout_transpose + * \brief Compares layout types if they are transposed with each other. + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + struct is_layout_transpose : public false_type + { + }; + + /*! \class layout_transpose + * \brief Transforms the layout type into its direct transpose. + * @tparam Layout the layout to transpose from + */ + template + struct layout_transpose + { + // using type = ... + }; + + /*! \class layout_transpose_t + * \brief Transforms the layout type into its direct transpose. + * @tparam Layout the layout to transpose from + */ + template + using layout_transpose_t = typename layout_transpose::type; + +} // namespace rocwmma + +#include "layout_traits_impl.hpp" + +#endif // ROCWMMA_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp new file mode 100644 index 00000000..0bcf3a1f --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp @@ -0,0 +1,69 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_LAYOUT_TRAITS_IMPL_HPP + +#include "config.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + // Common helpers for supported traits + namespace detail + { + // Based on the current config, determine the compatibility of the mma dimension + constexpr static inline bool testSupportedMmaDim(uint32_t testDim) + { + return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && testDim == 16u) + || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && (testDim == 16u || testDim == 32u)); + } + + // VW can be changed from vw0 to vw1 as long as they have the same maxVW, and that maxVW + // is a multiple of both vw values + constexpr static inline bool testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) + { + return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); + } + + } // namespace detail + + // Covers all other generic exact layout class matches + + // Self-compare is always true + template + struct is_layout_same : public true_type + { + }; + + // Self-compare is always false + template + struct is_layout_transpose : public false_type + { + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp similarity index 52% rename from library/include/rocwmma/internal/layout_impl.hpp rename to library/include/rocwmma/internal/layout/matrix_layout_impl.hpp index cdd965a5..430321e3 100644 --- a/library/include/rocwmma/internal/layout_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -23,16 +23,15 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef ROCWMMA_LAYOUT_IMPL_HPP -#define ROCWMMA_LAYOUT_IMPL_HPP +#ifndef ROCWMMA_MATRIX_LAYOUT_IMPL_HPP +#define ROCWMMA_MATRIX_LAYOUT_IMPL_HPP -#include "io_traits.hpp" #include "layout.hpp" -#include "mapping_util.hpp" -#include "utils.hpp" +#include "layout_traits.hpp" namespace rocwmma { + // Implementations for the MatrixLayout classes namespace MatrixLayout { @@ -527,442 +526,18 @@ namespace rocwmma ROCWMMA_DEVICE static inline auto debug() {} }; - //////////////// Interleaved patterns //////////////////////////////////// template // # of splits - struct ColInlineInt - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, - - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), - - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, - - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, - - BlockKStride_X = 0u, - BlockKStride_Y = 1u, - - VWStride_X = VectorWidth, - VWStride_Y = 0u, - - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = KPerThread / BlockKStride_Y, - VWSegs = DimPerThread / VWStride_X, - }; - - // Check VectorWidth validity - static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); - static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, - "DimPerThread not a multiple of VectorWidth"); - - // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); - - // Check SplitK validity - static_assert(BlockK >= SplitK, "Invalid SplitK"); - static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); - - // Check MfmaDim validity - static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); - static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowInlineInt; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - - return make_vector((uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); - } - - // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetX = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetX = (int32_t)Traits::VWStride_X; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::SplitKSegs > 1)) - { - // "Reset" cycle - VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::SplitKStride_X); - } - - return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); - } - - // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetX - = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::SplitKStride_X; - - return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); - } - ROCWMMA_DEVICE static inline auto debug() - { - if(threadIdx.x == 0 && threadIdx.y == 0) - { - printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", - (uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); - - printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " - "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", - (uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y, - (uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y, - (uint32_t)Traits::VWStride_X, - (uint32_t)Traits::VWStride_Y); - } - if(threadIdx.x <= 63 && threadIdx.y == 0) - { - printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", - threadIdx.x, - get<0>(baseOffset()), - get<1>(baseOffset())); - } - } - }; - template // # of splits - struct ColOrthoInt - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, - - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), - - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, - - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, - - BlockKStride_X = 1u, - BlockKStride_Y = 0u, - - VWStride_X = 0u, - VWStride_Y = VectorWidth, - - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = DimPerThread / BlockKStride_X, - VWSegs = KPerThread / VWStride_Y, - }; - - // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); - - // Check VectorWidth validity - static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); - static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, - "KPerThread not a multiple of VectorWidth"); - - // Check SplitK validity - static_assert(BlockK >= SplitK, "Invalid SplitK"); - static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); - - // Check MfmaDim validity - static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); - static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowOrthoInt; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); - } - - // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetX = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetX = (int32_t)Traits::VWStride_X; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::SplitKSegs > 1)) - { - // "Reset" cycle - VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::SplitKStride_X); - } - - return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); - } - - // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetX - = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::SplitKStride_X; - - return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); - } - - ROCWMMA_DEVICE static inline auto debug() - { - // if(threadIdx.x == 0 && threadIdx.y == 0) - // { - // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", - // (uint32_t)Traits::SplitKSegs, - // (uint32_t)Traits::BlockKSegs, - // (uint32_t)Traits::VWSegs); - - // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", - // (uint32_t)Traits::SplitKStride_X, - // (uint32_t)Traits::SplitKStride_Y, - // (uint32_t)Traits::BlockKStride_X, - // (uint32_t)Traits::BlockKStride_Y, - // (uint32_t)Traits::VWStride_X, - // (uint32_t)Traits::VWStride_Y); - - // } - // if(threadIdx.x <= 63 && threadIdx.y == 0) - // { - // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); - // } - } - }; - - template - struct RowInlineInt + uint32_t MaxVectorWidth> + struct RowInlineVW { // RowInlineVW is orthogonal to ColInlineVW, therefore we can use reversed coordinates struct Traits { - using OrthoLayout = ColInlineInt; + using OrthoLayout + = ColInlineVW; using MatrixCoordT = Coord2d; }; @@ -995,31 +570,21 @@ namespace rocwmma return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); } - ROCWMMA_DEVICE static inline auto debug() - { - Traits::OrthoLayout::debug(); - } + ROCWMMA_DEVICE static inline auto debug() {} }; template - struct RowOrthoInt + uint32_t MaxVectorWidth> + struct RowOrthoVW { // RowOrthoVW is orthogonal to ColOrthoVW, therefore we can use reversed coordinates struct Traits { - using OrthoLayout = ColOrthoInt; + using OrthoLayout + = ColOrthoVW; using MatrixCoordT = Coord2d; }; @@ -1046,7 +611,6 @@ namespace rocwmma { return swap(Traits::OrthoLayout::incrementalOffset(iteration)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT cumulativeOffset(uint32_t iteration) { @@ -1056,53 +620,29 @@ namespace rocwmma ROCWMMA_DEVICE static inline auto debug() {} }; - /////////////////////////// + } // namespace MatrixLayout + + namespace detail + { + template + struct is_ColOrthoVW : public false_type + { + }; template - struct RowInlineVW + struct is_ColOrthoVW< + MatrixLayout::template ColOrthoVW> + : public true_type { - // RowInlineVW is orthogonal to ColInlineVW, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout - = ColInlineVW; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } + }; - ROCWMMA_DEVICE static inline auto debug() {} + template + struct is_RowOrthoVW : public false_type + { }; template - struct RowOrthoVW + struct is_RowOrthoVW< + MatrixLayout::template RowOrthoVW> + : public true_type { - // RowOrthoVW is orthogonal to ColOrthoVW, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout - = ColOrthoVW; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } + }; + } - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } + //////////////////////////////////// + /// MatrixLayout specializations /// + //////////////////////////////////// - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } + // Matrix layout matching test criteria are if all parameters match, with some flexibility in VectorWidth. + template + class MatrixLayout> + struct is_layout_same, + MatrixLayout> + : public integral_constant + { + }; - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } + // Matrix layout transpose test with flexibility in the VectorWidth. + // Transposed matrix layouts swap matrix space rows / cols. + template + struct is_layout_transpose< + MatrixLayout::template ColOrthoVW, + MatrixLayout::template RowOrthoVW> + : public integral_constant + { + }; - ROCWMMA_DEVICE static inline auto debug() {} - }; + template + struct is_layout_transpose< + MatrixLayout::template RowOrthoVW, + MatrixLayout::template ColOrthoVW> + : public integral_constant + { + }; - } // namespace MatrixLayout + template + struct is_layout_transpose< + MatrixLayout::template ColInlineVW, + MatrixLayout::template RowInlineVW> + : public integral_constant + { + }; - template - struct RegisterLayoutOfMatrix; + template + struct is_layout_transpose< + MatrixLayout::template RowInlineVW, + MatrixLayout::template ColInlineVW> + : public integral_constant + { + }; + // Matrix space transpose guide: Swap rows / cols + // VW stays consistent. template - struct RegisterLayoutOfMatrix< + struct layout_transpose< MatrixLayout::template ColOrthoVW> { - using Type = RegisterLayout::template Soa; + using type = MatrixLayout:: + template RowOrthoVW; }; template - struct RegisterLayoutOfMatrix< - MatrixLayout::template ColInlineVW> + struct layout_transpose< + MatrixLayout::template RowOrthoVW> { - using Type = RegisterLayout::template Aos; + using type = MatrixLayout:: + template ColOrthoVW; }; template - struct RegisterLayoutOfMatrix< - MatrixLayout::template RowOrthoVW> + struct layout_transpose< + MatrixLayout::template ColInlineVW> { - using Type = RegisterLayout::template Soa; + using type = MatrixLayout:: + template RowInlineVW; }; template - struct RegisterLayoutOfMatrix< + struct layout_transpose< MatrixLayout::template RowInlineVW> { - using Type = RegisterLayout::template Aos; + using type = MatrixLayout:: + template ColInlineVW; }; - /// - /// Helper to obtain orthogonal data layout - /// - - // Data Layouts + /////////////////////////////////////// + /// Register layout specializations /// + /////////////////////////////////////// - template <> - struct OrthogonalLayout + // Register layouts are the same if all test parameters match, with some flexibility in VectorWidth. + template + class MatrixLayout> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout>, + RegisterLayout::template Storage< + MatrixLayout>> + : public integral_constant { - using Type = col_major; }; - template <> - struct OrthogonalLayout + // ColOrthoVW and RowOrthoVW layouts are already in mma input format for mma sized BlockDim (16 or 32) + template + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>, + RegisterLayout::template MmaInput> + : public integral_constant { - using Type = row_major; }; - template - struct OrthogonalLayout> + template + struct is_layout_same< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>> + : public integral_constant { - using Type = DataLayout::template Array1d::Type>; }; - // Matrix Layouts template - struct OrthogonalLayout< - MatrixLayout::template ColOrthoVW> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>, + RegisterLayout::template MmaInput> + : public integral_constant { - using Type = MatrixLayout:: - template RowOrthoVW; }; template - struct OrthogonalLayout< - MatrixLayout::template ColInlineVW> + struct is_layout_same< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>> + : public integral_constant { - using Type = MatrixLayout:: - template RowInlineVW; }; + // TODO: necessary? + // In-register layouts for transposed RowOrthoVW / ColOrthoVW and RowInlineVW / ColInline are technically 'the same' + // for each thread, even though the data interpretation is different (e.g., row elements vs col elements). template - struct OrthogonalLayout< - MatrixLayout::template RowOrthoVW> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>> + : public integral_constant { - using Type = MatrixLayout:: - template ColOrthoVW; }; template - struct OrthogonalLayout< - MatrixLayout::template RowInlineVW> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>> + : public integral_constant { - using Type = MatrixLayout:: - template ColInlineVW; }; template - struct OrthogonalLayout> + uint32_t VectorWidthLhs, + uint32_t VectorWidthRhs, + uint32_t MaxVectorWidth> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>> + : public integral_constant { - using Type = MatrixLayout::template RowOrthoInt; }; + template + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>> + : public integral_constant + { + }; + + // ColInlineVW and RowInlineVW layouts are transposed to mma input format for mma sized BlockDim (16 or 32) template - struct OrthogonalLayout> + uint32_t MaxVectorWidth> + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>, + RegisterLayout::template MmaInput> + : public integral_constant { - using Type = MatrixLayout::template RowInlineInt; }; template - struct OrthogonalLayout> + uint32_t MaxVectorWidth> + struct is_layout_transpose< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>> + : public integral_constant { - using Type = MatrixLayout::template ColOrthoInt; }; template - struct OrthogonalLayout> + uint32_t MaxVectorWidth> + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>, + RegisterLayout::template MmaInput> + : public integral_constant { - using Type = MatrixLayout::template ColInlineInt; }; - // Register layouts - template - struct OrthogonalLayout> + template + struct is_layout_transpose< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>> + : public integral_constant { - using Type = RegisterLayout::template Soa; }; - template - struct OrthogonalLayout> + // In-register layouts for (ColInlineVW / ColOrthoVW) and (RowInlineVW / RowOrthoVW) are the orthogonal register transposes. + template + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>> + : public integral_constant { - using Type = RegisterLayout::template Aos; }; - /// - /// Helper to check if layout types are orthogonal - /// + template + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>> + : public integral_constant + { + }; - // In general, assume that an orthogonal layout has been assigned - template - struct is_orthogonal + template + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>> : public integral_constant, RhsDataLayout>> + detail::testSupportedVW( + MaxVectorWidth, VectorWidthLhs, VectorWidthRhs)> { }; - // Special case for self: not orthogonal - template - struct is_orthogonal : public false_type + template + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>> + : public integral_constant { }; - // Special cases for MatrixLayouts, the VectorWidth used does not matter in determining orthogonality, however all other properties must match. + // In-register layouts for (RowOrthoVW / ColInlineVW) and (ColOrthoVW / RowInlineVW) are the orthogonal register transposes. template - struct is_orthogonal< - MatrixLayout::template ColOrthoVW, - MatrixLayout::template RowOrthoVW> - : public true_type + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>> + : public integral_constant { }; template - struct is_orthogonal< - MatrixLayout::template RowOrthoVW, - MatrixLayout::template ColOrthoVW> - : public true_type + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template ColInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowOrthoVW>> + : public integral_constant { }; template - struct is_orthogonal< - MatrixLayout::template ColInlineVW, - MatrixLayout::template RowInlineVW> - : public true_type + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>> + : public integral_constant { }; template - struct is_orthogonal< - MatrixLayout::template RowInlineVW, - MatrixLayout::template ColInlineVW> - : public true_type + struct is_layout_transpose< + RegisterLayout::template Storage< + MatrixLayout:: + template RowInlineVW>, + RegisterLayout::template Storage< + MatrixLayout:: + template ColOrthoVW>> + : public integral_constant { }; } // namespace rocwmma -#endif // ROCWMMA_LAYOUT_IMPL_HPP +#endif // ROCWMMA_MATRIX_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp new file mode 100644 index 00000000..b7f5be33 --- /dev/null +++ b/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp @@ -0,0 +1,1348 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP +#define ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + + // Implementations for the interleaved MatrixLayout classes + namespace MatrixLayout + { + template // # of splits + struct ColInlineInt + { + using IOTraits = IOTraits; + struct Traits + { + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 0u, + BlockKStride_Y = 1u, + + VWStride_X = VectorWidth, + VWStride_Y = 0u, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = KPerThread / BlockKStride_Y, + VWSegs = DimPerThread / VWStride_X, + }; + + // Check VectorWidth validity + static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); + static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, + "DimPerThread not a multiple of VectorWidth"); + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowInlineInt; + + using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + + return make_vector((uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + ROCWMMA_DEVICE static inline auto debug() + { + if(threadIdx.x == 0 && threadIdx.y == 0) + { + printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + (uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); + + printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " + "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + (uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y, + (uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y, + (uint32_t)Traits::VWStride_X, + (uint32_t)Traits::VWStride_Y); + } + if(threadIdx.x <= 63 && threadIdx.y == 0) + { + printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", + threadIdx.x, + get<0>(baseOffset()), + get<1>(baseOffset())); + } + } + }; + + template // # of splits + struct ColOrthoInt + { + using IOTraits = IOTraits; + struct Traits + { + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 1u, + BlockKStride_Y = 0u, + + VWStride_X = 0u, + VWStride_Y = VectorWidth, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = DimPerThread / BlockKStride_X, + VWSegs = KPerThread / VWStride_Y, + }; + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // Check VectorWidth validity + static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); + static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, + "KPerThread not a multiple of VectorWidth"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowOrthoInt; + + using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments + (uint32_t)Traits::BlockKSegs, // BlockK Segments + (uint32_t)Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto debug() + { + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); + // } + } + }; + + template + struct RowInlineInt + { + // RowInlineInt is orthogonal to ColInlineInt, therefore we can use reversed coordinates + struct Traits + { + using OrthoLayout = ColInlineInt; + + using MatrixCoordT = Coord2d; + }; + + // Matrix coord offsets + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return swap(Traits::OrthoLayout::baseOffset()); + } + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return Traits::OrthoLayout::strideCounts(); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + auto t = Traits::OrthoLayout::strides(); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + } + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto debug() + { + Traits::OrthoLayout::debug(); + } + }; + + template + struct RowOrthoInt + { + // RowOrthoInt is orthogonal to ColOrthoInt, therefore we can use reversed coordinates + struct Traits + { + using OrthoLayout = ColOrthoInt; + + using MatrixCoordT = Coord2d; + }; + + // Matrix coord offsets + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return swap(Traits::OrthoLayout::baseOffset()); + } + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return Traits::OrthoLayout::strideCounts(); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + auto t = Traits::OrthoLayout::strides(); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + } + + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto debug() {} + }; + + } // namespace MatrixLayout + + //////////////////////////////////// + /// MatrixLayout specializations /// + //////////////////////////////////// + + // Matrix layout matching test criteria are if all parameters match, with some flexibility in VectorWidth. + template + class MatrixLayout> + struct is_layout_same< + MatrixLayout, + MatrixLayout> + : public integral_constant + { + }; + + // Matrix layout transpose test with flexibility in the VectorWidth. + // Transposed matrix layouts swap matrix space rows / cols. + template + struct is_layout_transpose, + MatrixLayout::template RowOrthoInt> + : public integral_constant + { + }; + + template + struct is_layout_transpose, + MatrixLayout::template ColOrthoInt> + : public integral_constant + { + }; + + template + struct is_layout_transpose, + MatrixLayout::template RowInlineInt> + : public integral_constant + { + }; + + template + struct is_layout_transpose, + MatrixLayout::template ColInlineInt> + : public integral_constant + { + }; + + // Matrix space transpose guide: Swap rows / cols + // VW stays consistent. + template + struct layout_transpose> + { + using type = MatrixLayout::template RowOrthoInt; + }; + + template + struct layout_transpose> + { + using type = MatrixLayout::template ColOrthoInt; + }; + + template + struct layout_transpose> + { + using type = MatrixLayout::template RowInlineInt; + }; + + template + struct layout_transpose> + { + using type = MatrixLayout::template ColInlineInt; + }; + + /////////////////////////////////////// + /// Register layout specializations /// + /////////////////////////////////////// + + // Register layouts are the same if all test parameters match, with some flexibility in VectorWidth. + template + class MatrixLayout> + struct is_layout_same< + RegisterLayout::template Storage< + MatrixLayout>, + RegisterLayout::template Storage< + MatrixLayout>> + : public integral_constant + { + }; + + // ColOrthoInt and RowOrthoInt layouts are already in mma input format for mma sized BlockDim (16 or 32) + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template MmaInput> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template MmaInput> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + // TODO: necessary? + // In-register layouts for transposed RowOrthoInt / ColOrthoInt and RowInlineInt / ColInline are technically 'the same' + // for each thread, even though the data interpretation is different (e.g., row elements vs col elements). + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_same< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + // ColInlineInt and RowInlineInt layouts are transposed to mma input format for mma sized BlockDim (16 or 32) + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template MmaInput> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template MmaInput> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template MmaInput, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + // In-register layouts for (ColInlineInt / ColOrthoInt) and (RowInlineInt / RowOrthoInt) are the orthogonal register transposes. + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + // In-register layouts for (RowOrthoInt / ColInlineInt) and (ColOrthoInt / RowInlineInt) are the orthogonal register transposes. + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + + template + struct is_layout_transpose< + RegisterLayout::template Storage>, + RegisterLayout::template Storage>> + : public integral_constant + { + }; + +} // namespace rocwmma + +#endif // ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_impl.hpp new file mode 100644 index 00000000..9118f08f --- /dev/null +++ b/library/include/rocwmma/internal/layout/register_layout_impl.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_REGISTER_LAYOUT_IMPL_HPP +#define ROCWMMA_REGISTER_LAYOUT_IMPL_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" +#include "utility/type_traits.hpp" + +namespace rocwmma +{ + // Use generic MatrixLayout transpose rules to guide the register layout transpose suggestion + template + struct layout_transpose> + { + using type = RegisterLayout::template Storage>; + }; + +} // namespace rocwmma + +#endif // ROCWMMA_REGISTER_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/transforms.hpp b/library/include/rocwmma/internal/transforms.hpp index acbc4c69..1844ca62 100644 --- a/library/include/rocwmma/internal/transforms.hpp +++ b/library/include/rocwmma/internal/transforms.hpp @@ -59,10 +59,189 @@ namespace rocwmma template using AosToSoa = Driver>; - + template using SoaToAos = Driver>; - + + // Note: If you arrive at an undefined register_transform error, it is likely + // the layout transformation is not currently supported. Need to either implement + // the transform or ensure your layout transform mapping is correct. + template > + struct register_transform; + + // Layouts that are identical do not require register transformations + template + struct register_transform + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) + exec(VecT const& v) + { + return v; + } + }; + + /////// To MmaInput /////// + + // ColInlineVW and RowInlineVW layouts are not mma friendly and require Aos->Soa transform. + // Only valid for BlockDims that supported by mma + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::ColInlineVW>, + RegisterLayout::MmaInput, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(RegisterLayout::detail::testSupportedMmaDim(BlockDim), + "Unsupported mma dim"); + + // ColInlineVW -> ColOrthoVW (mma friendly) = AOS -> SOA transform + return AosToSoa::exec(v); + } + }; + + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::RowInlineVW>, + RegisterLayout::MmaInput, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(RegisterLayout::detail::testSupportedMmaDim(BlockDim), + "Unsupported mma dim"); + + // RowInlineVW -> RowOrthoVW (mma friendly) = AOS -> SOA transform + return AosToSoa::exec(v); + } + }; + + /////// To Other Layouts /////// + + // In-register layouts for (RowInlineVW and RowOrthoVW), and (ColInlineVW and ColOrthoVW) are orthgonal + // and need specific transforms to transition between either representation. + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::RowInlineVW>, + RegisterLayout::Storage< + MatrixLayout::RowOrthoVW>, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(RegisterLayout::detail::testSupportedVW( + MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), + "Invalid VW"); + + // RowInlineVW -> RowOrthoVW = AOS -> SOA transform + return AosToSoa::exec(v); + } + }; + + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::RowOrthoVW>, + RegisterLayout::Storage< + MatrixLayout::RowInlineVW>, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(RegisterLayout::detail::testSupportedVW( + MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), + "Invalid VW"); + + // RowOrthoVW -> RowInlineVW = SOA -> AOS transform + return SoaToAos::exec(v); + } + }; + + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::ColInlineVW>, + RegisterLayout::Storage< + MatrixLayout::ColOrthoVW>, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(RegisterLayout::detail::testSupportedVW( + MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), + "Invalid VW"); + + // ColInlineVW -> ColOrthoVW = AOS -> SOA transform + return AosToSoa::exec(v); + } + }; + + template + struct register_transform< + RegisterLayout::Storage< + MatrixLayout::ColOrthoVW>, + RegisterLayout::Storage< + MatrixLayout::ColInlineVW>, + false_type> + { + // TODO: Remove DataT from the transform + template + ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) + { + static_assert(0, "Nope"); + static_assert(RegisterLayout::detail::testSupportedVW( + MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), + "Invalid VW"); + + // ColOrthoVW -> ColInlineVW = SOA -> AOS transform + return SoaToAos::exec(v); + } + }; + } // namespace Transforms } // namespace rocwmma diff --git a/library/include/rocwmma/internal/transforms_impl.hpp b/library/include/rocwmma/internal/transforms_impl.hpp index e9e4ebc7..97981e2e 100644 --- a/library/include/rocwmma/internal/transforms_impl.hpp +++ b/library/include/rocwmma/internal/transforms_impl.hpp @@ -213,7 +213,6 @@ namespace rocwmma { namespace Ops { - template struct AosToSoa { diff --git a/library/include/rocwmma/internal/vector_util.hpp b/library/include/rocwmma/internal/vector_util.hpp index abc56c4f..bd6a97c4 100644 --- a/library/include/rocwmma/internal/vector_util.hpp +++ b/library/include/rocwmma/internal/vector_util.hpp @@ -122,6 +122,17 @@ namespace rocwmma template ROCWMMA_DEVICE constexpr static inline auto unpackHi(VecT const& v0, VecT const& v1); + + //! Interleaves elements from the vector, according to group size + //! E.g. GroupSize = 4 + //! v0 = [0, 1, 2, 3, 4, 5, 6, 7] + //! result = [0, 4, 1, 5, 2, 6, 3, 7] + /*! + \param v0 Vector from which interleaved elements are selected from + */ + template + ROCWMMA_DEVICE constexpr static inline auto interleave(VecT const& v0); + } // namespace rocwmma #include "vector_util_impl.hpp" diff --git a/library/include/rocwmma/internal/vector_util_impl.hpp b/library/include/rocwmma/internal/vector_util_impl.hpp index e0115a46..1096757e 100644 --- a/library/include/rocwmma/internal/vector_util_impl.hpp +++ b/library/include/rocwmma/internal/vector_util_impl.hpp @@ -416,6 +416,20 @@ namespace rocwmma } } + template + ROCWMMA_DEVICE constexpr static inline auto interleave(VecT const& v0) + { + // Interleave groups + auto offset = [](auto&& idx, auto&& v0) { + constexpr auto Index = decay_t::value; + constexpr auto Offset0 = Index * GroupSize; + constexpr auto Offset1 = Index / (VecSize / GroupSize); + return get<(Offset0 + Offset1) % VecSize>(v0); + }; + + return vector_generator()(offset, v0); + } + } // namespace rocwmma #endif // ROCWMMA_VECTOR_UTIL_IMPL_HPP diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index deea2b45..b53b619c 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -349,10 +349,10 @@ namespace rocwmma typename IOConfigB::IOLayout::RegisterLayout>, "Input fragment register layouts do not match"); - static_assert(is_same_v>, - "Input fragment register layouts are not mfma friendly"); + // static_assert(is_same_v>, + // "Input fragment register layouts are not mfma friendly"); // Gfx9 uses MFMA, gfx11 uses WMMA using MMA = conditional_t struct ApplyTranspose; @@ -65,7 +65,9 @@ namespace rocwmma // Original frag A type using FragA = fragment; - // Transpose to frag B type in opposite data layout. + // Transpose to frag B type in opposite data layout: + // - Exchange Block M for BlockN + // - Exchange row_major for col_major and vice-versa using FragB = fragment, @@ -115,10 +117,12 @@ namespace rocwmma struct ApplyTranspose> { private: - // Original frag A type + // Original frag B type using FragB = fragment; - // Transpose to frag A type in opposite data layout. + // Transpose to frag A type in opposite data layout: + // - Exchange Block M for BlockN + // - Exchange row_major for col_major and vice-versa using FragA = fragment && is_same_v, - int> = 0> + template + && is_same_v, + int> + = 0> ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(FragT const& frag) { return reinterpret_cast(frag); } // Input and output register layouts do not match: must transform using AOS<->SOA - template < - uint32_t WaveCount = 1, - typename FragT, - enable_if_t< - is_same_v && !is_same_v, - int> = 0> + template + && !is_same_v, + int> + = 0> ROCWMMA_DEVICE constexpr static inline auto exec(FragT const& frag) { // TODO: Make sure to use coop configs to get the right MaxVW!!! - using IOConfigCoop = GetCoopIOConfig_t; - constexpr uint32_t BlockDim = IOConfigCoop::IOShape::BlockDim; - constexpr uint32_t MaxVW = IOConfigCoop::IOLayout::MaxVW; - using RegisterLayoutIncoming = typename IOConfigCoop::IOLayout::RegisterLayout; - - // Target layouts - using AosLayout = RegisterLayout::template Aos; - using SoaLayout = RegisterLayout::template Soa; - - auto result = FragOut{}; - - if constexpr(is_same_v) - { - result.mAccess = Transforms::AosToSoa::exec(frag.mAccess); - } - else if constexpr(is_same_v) - { - result.mAccess = Transforms::SoaToAos::exec(frag.mAccess); - } + // using IOConfigCoopIn = GetCoopIOConfig_t; + // constexpr uint32_t BlockDim = IOConfigCoop::IOShape::BlockDim; + // constexpr uint32_t MaxVW = IOConfigCoop::IOLayout::MaxVW; + // using RegisterLayoutIncoming = typename IOConfigCoop::IOLayout::RegisterLayout; + + // // Target layouts + // using AosLayout = RegisterLayout::template Aos; + // using SoaLayout = RegisterLayout::template Soa; + + using SrcRegLayout = + typename GetCoopIOConfig_t::IOLayout::RegisterLayout; + using DstRegLayout = + typename GetCoopIOConfig_t::IOLayout::RegisterLayout; + + auto result = FragOut{ + Transforms::register_transform::exec(frag.mAccess)}; + //result.mAccess = ; + + // if constexpr(is_same_v) + // { + // result.mAccess = Transforms::AosToSoa::exec(frag.mAccess); + // } + // else if constexpr(is_same_v) + // { + // result.mAccess = Transforms::SoaToAos::exec(frag.mAccess); + // } return result; } From cdd13bc503e25b2a941a109286281b513662fa80 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 12 Sep 2024 20:23:04 +0000 Subject: [PATCH 04/36] Refactor layout and traits organization --- .../internal/layout/data_layout_impl.hpp | 79 - .../internal/layout/data_layout_traits.hpp | 104 ++ .../rocwmma/internal/layout/layout.hpp | 15 +- .../rocwmma/internal/layout/layout_traits.hpp | 44 +- .../internal/layout/layout_traits_impl.hpp | 42 +- .../internal/layout/matrix_layout_impl.hpp | 947 +++++------- .../layout/matrix_layout_interleaved_impl.hpp | 1305 ----------------- .../internal/layout/matrix_layout_traits.hpp | 445 ++++++ .../internal/layout/register_layout_impl.hpp | 44 - .../layout/register_layout_traits.hpp | 352 +++++ 10 files changed, 1326 insertions(+), 2051 deletions(-) delete mode 100644 library/include/rocwmma/internal/layout/data_layout_impl.hpp create mode 100644 library/include/rocwmma/internal/layout/data_layout_traits.hpp create mode 100644 library/include/rocwmma/internal/layout/matrix_layout_traits.hpp delete mode 100644 library/include/rocwmma/internal/layout/register_layout_impl.hpp create mode 100644 library/include/rocwmma/internal/layout/register_layout_traits.hpp diff --git a/library/include/rocwmma/internal/layout/data_layout_impl.hpp b/library/include/rocwmma/internal/layout/data_layout_impl.hpp deleted file mode 100644 index eccbb67a..00000000 --- a/library/include/rocwmma/internal/layout/data_layout_impl.hpp +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_DATA_LAYOUT_IMPL_HPP -#define ROCWMMA_DATA_LAYOUT_IMPL_HPP - -#include "layout.hpp" -#include "layout_traits.hpp" -#include "utility/type_traits.hpp" - -namespace rocwmma -{ - // Data layout trait tags are transposes - template <> - struct is_layout_transpose : public true_type - { - }; - - template <> - struct is_layout_transpose : public true_type - { - }; - - // Data layout objects are transposes - template <> - struct is_layout_transpose : public true_type - { - }; - - template <> - struct is_layout_transpose : public true_type - { - }; - - // Data layout trait tag transpose - template <> - struct layout_transpose - { - using type = col_major; - }; - - template <> - struct layout_transpose - { - using type = row_major; - }; - - // Data layout object type transpose - template - struct layout_transpose> - { - using Type = DataLayout::template Array1d>; - }; - -} // namespace rocwmma - -#endif // ROCWMMA_DATA_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/data_layout_traits.hpp b/library/include/rocwmma/internal/layout/data_layout_traits.hpp new file mode 100644 index 00000000..6fcb50b0 --- /dev/null +++ b/library/include/rocwmma/internal/layout/data_layout_traits.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_DATA_LAYOUT_TRAITS_HPP +#define ROCWMMA_DATA_LAYOUT_TRAITS_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + namespace LayoutTraits_impl + { + // Sameness classifier + template <> + struct is_layout_same : public true_type + { + }; + + template <> + struct is_layout_same : public true_type + { + }; + + template <> + struct is_layout_same : public true_type + { + }; + + template <> + struct is_layout_same : public true_type + { + }; + + // Orthogonality classifier + template <> + struct is_layout_orthogonal : public true_type + { + }; + + template <> + struct is_layout_orthogonal : public true_type + { + }; + + template <> + struct is_layout_orthogonal + : public true_type + { + }; + + template <> + struct is_layout_orthogonal + : public true_type + { + }; + + // Orthogonal layout guides + template <> + struct orthogonal_layout + { + using type = col_major; + }; + + template <> + struct orthogonal_layout + { + using type = row_major; + }; + + template + struct orthogonal_layout> + { + using Type + = DataLayout::template Array1d::type>; + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#endif // ROCWMMA_DATA_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index 0c30d615..a50d89cf 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -26,6 +26,7 @@ #ifndef ROCWMMA_LAYOUT_HPP #define ROCWMMA_LAYOUT_HPP +#include "api_fwd.hpp" #include "mapping_util.hpp" namespace rocwmma @@ -101,12 +102,9 @@ namespace rocwmma uint32_t MaxVectorWidth> struct RowInlineVW; - /////////////////// Interleaved patterns ////////////////// template // # of splits struct ColInlineInt; @@ -114,8 +112,6 @@ namespace rocwmma template // # of splits struct ColOrthoInt; @@ -123,8 +119,6 @@ namespace rocwmma template // # of splits struct RowInlineInt; @@ -132,14 +126,10 @@ namespace rocwmma template // # of splits struct RowOrthoInt; - /////////////////// ////////////////////////////// ////////////////// - } // namespace MatrixLayout // Register layouts describe in-register layout and serve as transform states, or endpoints. @@ -176,9 +166,6 @@ namespace rocwmma } // namespace rocwmma -#include "data_layout_impl.hpp" #include "matrix_layout_impl.hpp" -#include "matrix_layout_interleaved_impl.hpp" -#include "register_layout_impl.hpp" #endif // ROCWMMA_LAYOUT_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp index 39c1c161..60002937 100644 --- a/library/include/rocwmma/internal/layout/layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -26,7 +26,10 @@ #ifndef ROCWMMA_LAYOUT_TRAITS_HPP #define ROCWMMA_LAYOUT_TRAITS_HPP -#include "utility/type_traits.hpp" +#include "data_layout_traits.hpp" +#include "layout_traits_impl.hpp" +#include "matrix_layout_traits.hpp" +#include "register_layout_traits.hpp" namespace rocwmma { @@ -39,39 +42,54 @@ namespace rocwmma * @tparam RhsLayout Comparative right hand side */ template - struct is_layout_same : public false_type + struct is_layout_same : public LayoutTraits_impl::is_layout_same { }; - /*! \class is_layout_transpose - * \brief Compares layout types if they are transposed with each other. + /*! \class is_layout_orthogonal + * \brief Compares layout types if they are orthogonal with each other. * @tparam LhsLayout Comparative left hand side * @tparam RhsLayout Comparative right hand side */ template - struct is_layout_transpose : public false_type + struct is_layout_orthogonal + : public LayoutTraits_impl::is_layout_orthogonal { }; - /*! \class layout_transpose - * \brief Transforms the layout type into its direct transpose. + /*! \class orthogonal_layout + * \brief Transforms the layout type into its orthogonal layout. * @tparam Layout the layout to transpose from */ template - struct layout_transpose + struct orthogonal_layout : public LayoutTraits_impl::orthogonal_layout { - // using type = ... }; + /*! \class is_layout_same_v + * \brief Evaluates is_layout_same + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + constexpr static inline bool is_layout_same_v = is_layout_same::value; + + /*! \class is_layout_orthogonal + * \brief Evaluates is_layout_orthogonal + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + constexpr static inline bool is_layout_orthogonal_v + = is_layout_orthogonal::value; + /*! \class layout_transpose_t - * \brief Transforms the layout type into its direct transpose. + * \brief Transforms the layout type into its orthogonal layout. * @tparam Layout the layout to transpose from */ template - using layout_transpose_t = typename layout_transpose::type; + using orthogonal_layout_t = typename orthogonal_layout::type; } // namespace rocwmma -#include "layout_traits_impl.hpp" - #endif // ROCWMMA_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp index 0bcf3a1f..b7647aa1 100644 --- a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp @@ -26,43 +26,29 @@ #ifndef ROCWMMA_LAYOUT_TRAITS_IMPL_HPP #define ROCWMMA_LAYOUT_TRAITS_IMPL_HPP -#include "config.hpp" -#include "layout_traits.hpp" +#include "utility/type_traits.hpp" namespace rocwmma { - // Common helpers for supported traits - namespace detail + namespace LayoutTraits_impl { - // Based on the current config, determine the compatibility of the mma dimension - constexpr static inline bool testSupportedMmaDim(uint32_t testDim) + // Classifier to test layout sameness + template + struct is_layout_same : public false_type { - return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && testDim == 16u) - || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && (testDim == 16u || testDim == 32u)); - } + }; - // VW can be changed from vw0 to vw1 as long as they have the same maxVW, and that maxVW - // is a multiple of both vw values - constexpr static inline bool testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) + // Classifer to test layout orthogonality + template + struct is_layout_orthogonal : public false_type { - return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); - } + }; - } // namespace detail + // Orthogonality guide + template + struct orthogonal_layout; - // Covers all other generic exact layout class matches - - // Self-compare is always true - template - struct is_layout_same : public true_type - { - }; - - // Self-compare is always false - template - struct is_layout_transpose : public false_type - { - }; + } // namespace LayoutTraits_impl } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp index 430321e3..8710c8ba 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -529,68 +529,414 @@ namespace rocwmma template - struct RowInlineVW + uint32_t MfmaDim, // MFMA instruction size + uint32_t SplitK /* = 1*/> // # of splits + struct ColInlineInt { - // RowInlineVW is orthogonal to ColInlineVW, therefore we can use reversed coordinates + using IOTraits = IOTraits; struct Traits { - using OrthoLayout - = ColInlineVW; + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 0u, + BlockKStride_Y = 1u, + + VWStride_X = DimPerThread, + VWStride_Y = 0u, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = KPerThread / BlockKStride_Y, + VWSegs = DimPerThread / VWStride_X, + }; + + // // Check VectorWidth validity + // static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); + // static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, + // "DimPerThread not a multiple of VectorWidth"); + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowInlineInt; using MatrixCoordT = Coord2d; }; - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - return swap(Traits::OrthoLayout::baseOffset()); + + return make_vector((uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); } - ROCWMMA_DEVICE constexpr static inline auto strideCounts() + ROCWMMA_DEVICE constexpr static inline auto strides() { - return Traits::OrthoLayout::strideCounts(); + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); } - ROCWMMA_DEVICE constexpr static inline auto strides() + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() { - auto t = Traits::OrthoLayout::strides(); - return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); } + // Incremental iteration offset ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT incrementalOffset(uint32_t iteration) { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); } + + // Cumulative iteration offset ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT cumulativeOffset(uint32_t iteration) { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; - ROCWMMA_DEVICE static inline auto debug() {} + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + ROCWMMA_DEVICE static inline auto debug() + { + if(threadIdx.x == 0 && threadIdx.y == 0) + { + printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + (uint32_t)Traits::SplitKSegs, + (uint32_t)Traits::BlockKSegs, + (uint32_t)Traits::VWSegs); + + printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " + "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + (uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y, + (uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y, + (uint32_t)Traits::VWStride_X, + (uint32_t)Traits::VWStride_Y); + } + if(threadIdx.x <= 63 && threadIdx.y == 0) + { + printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", + threadIdx.x, + get<0>(baseOffset()), + get<1>(baseOffset())); + } + } }; template - struct RowOrthoVW + uint32_t MfmaDim, // MFMA instruction size + uint32_t SplitK /*= 1*/> // # of splits + struct ColOrthoInt { - // RowOrthoVW is orthogonal to ColOrthoVW, therefore we can use reversed coordinates + using IOTraits = IOTraits; struct Traits { - using OrthoLayout - = ColOrthoVW; + enum : uint32_t + { + // Number of threads per wave + WaveSize = IOTraits::ThreadsPerIO, + + // Number of elements each thread will fetch in BlockDim direction + DimPerThread = BlockDim / MfmaDim, + + // Number of elements each thread will fetch in BlockK direction + KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + + // Number of elements that each thread is responsible for + ElementsPerThread = DimPerThread * KPerThread, + + // Strides + SplitKStride_X = 0u, + SplitKStride_Y = BlockK / SplitK, + + BlockKStride_X = 1u, + BlockKStride_Y = 0u, + + VWStride_X = 0u, + VWStride_Y = DimPerThread, + + // Stride Space + SplitKSegs = BlockK / SplitKStride_Y, + BlockKSegs = DimPerThread / BlockKStride_X, + VWSegs = KPerThread / VWStride_Y, + }; + + // Check KPerThread validity + static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); + static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, + "BlockK is not a multiple of KPerThread"); + + // // Check VectorWidth validity + // static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); + // static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, + // "KPerThread not a multiple of VectorWidth"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + using OrthoLayout = RowOrthoInt; using MatrixCoordT = Coord2d; }; - // Matrix coord offsets + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments + (uint32_t)Traits::BlockKSegs, // BlockK Segments + (uint32_t)Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector( + make_coord2d((uint32_t)Traits::SplitKStride_X, + (uint32_t)Traits::SplitKStride_Y), + make_coord2d((uint32_t)Traits::BlockKStride_X, + (uint32_t)Traits::BlockKStride_Y), + make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + } + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + { + return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) + % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT + cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto debug() + { + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); + // } + } + }; + + template + struct OrthoImpl + { + struct Traits + { + using OrthoLayout = orthogonal_layout_t; + }; + + // Matrix coord offsets + ROCWMMA_DEVICE static inline typename auto baseOffset() { return swap(Traits::OrthoLayout::baseOffset()); } @@ -606,13 +952,12 @@ namespace rocwmma return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { return swap(Traits::OrthoLayout::incrementalOffset(iteration)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) + + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); } @@ -620,546 +965,12 @@ namespace rocwmma ROCWMMA_DEVICE static inline auto debug() {} }; - } // namespace MatrixLayout - - namespace detail - { - template - struct is_ColOrthoVW : public false_type - { - }; - - template - struct is_ColOrthoVW< - MatrixLayout::template ColOrthoVW> - : public true_type - { - }; - - template - struct is_RowOrthoVW : public false_type - { - }; + using RowOrthoVW = OrthoImpl; + using RowInlineVW = OrthoImpl; + using RowOrthoInt = OrthoImpl; + using RowInlineInt = OrthoImpl; - template - struct is_RowOrthoVW< - MatrixLayout::template RowOrthoVW> - : public true_type - { - }; - } - - //////////////////////////////////// - /// MatrixLayout specializations /// - //////////////////////////////////// - - // Matrix layout matching test criteria are if all parameters match, with some flexibility in VectorWidth. - template - class MatrixLayout> - struct is_layout_same, - MatrixLayout> - : public integral_constant - { - }; - - // Matrix layout transpose test with flexibility in the VectorWidth. - // Transposed matrix layouts swap matrix space rows / cols. - template - struct is_layout_transpose< - MatrixLayout::template ColOrthoVW, - MatrixLayout::template RowOrthoVW> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - MatrixLayout::template RowOrthoVW, - MatrixLayout::template ColOrthoVW> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - MatrixLayout::template ColInlineVW, - MatrixLayout::template RowInlineVW> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - MatrixLayout::template RowInlineVW, - MatrixLayout::template ColInlineVW> - : public integral_constant - { - }; - - // Matrix space transpose guide: Swap rows / cols - // VW stays consistent. - template - struct layout_transpose< - MatrixLayout::template ColOrthoVW> - { - using type = MatrixLayout:: - template RowOrthoVW; - }; - - template - struct layout_transpose< - MatrixLayout::template RowOrthoVW> - { - using type = MatrixLayout:: - template ColOrthoVW; - }; - - template - struct layout_transpose< - MatrixLayout::template ColInlineVW> - { - using type = MatrixLayout:: - template RowInlineVW; - }; - - template - struct layout_transpose< - MatrixLayout::template RowInlineVW> - { - using type = MatrixLayout:: - template ColInlineVW; - }; - - /////////////////////////////////////// - /// Register layout specializations /// - /////////////////////////////////////// - - // Register layouts are the same if all test parameters match, with some flexibility in VectorWidth. - template - class MatrixLayout> - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout>, - RegisterLayout::template Storage< - MatrixLayout>> - : public integral_constant - { - }; - - // ColOrthoVW and RowOrthoVW layouts are already in mma input format for mma sized BlockDim (16 or 32) - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>> - : public integral_constant - { - }; - - // TODO: necessary? - // In-register layouts for transposed RowOrthoVW / ColOrthoVW and RowInlineVW / ColInline are technically 'the same' - // for each thread, even though the data interpretation is different (e.g., row elements vs col elements). - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>> - : public integral_constant - { - }; - - // ColInlineVW and RowInlineVW layouts are transposed to mma input format for mma sized BlockDim (16 or 32) - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>> - : public integral_constant - { - }; - - // In-register layouts for (ColInlineVW / ColOrthoVW) and (RowInlineVW / RowOrthoVW) are the orthogonal register transposes. - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>> - : public integral_constant - { - }; - - // In-register layouts for (RowOrthoVW / ColInlineVW) and (ColOrthoVW / RowInlineVW) are the orthogonal register transposes. - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template ColInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowOrthoVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage< - MatrixLayout:: - template RowInlineVW>, - RegisterLayout::template Storage< - MatrixLayout:: - template ColOrthoVW>> - : public integral_constant - { - }; + } // namespace MatrixLayout } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp index b7f5be33..e866017b 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp @@ -35,1314 +35,9 @@ namespace rocwmma // Implementations for the interleaved MatrixLayout classes namespace MatrixLayout { - template // # of splits - struct ColInlineInt - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, - - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), - - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, - - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, - - BlockKStride_X = 0u, - BlockKStride_Y = 1u, - - VWStride_X = VectorWidth, - VWStride_Y = 0u, - - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = KPerThread / BlockKStride_Y, - VWSegs = DimPerThread / VWStride_X, - }; - - // Check VectorWidth validity - static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); - static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, - "DimPerThread not a multiple of VectorWidth"); - - // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); - - // Check SplitK validity - static_assert(BlockK >= SplitK, "Invalid SplitK"); - static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); - - // Check MfmaDim validity - static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); - static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowInlineInt; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - - return make_vector((uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); - } - - // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetX = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetX = (int32_t)Traits::VWStride_X; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::SplitKSegs > 1)) - { - // "Reset" cycle - VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::SplitKStride_X); - } - - return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); - } - - // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetX - = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::SplitKStride_X; - - return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); - } - ROCWMMA_DEVICE static inline auto debug() - { - if(threadIdx.x == 0 && threadIdx.y == 0) - { - printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", - (uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); - - printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " - "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", - (uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y, - (uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y, - (uint32_t)Traits::VWStride_X, - (uint32_t)Traits::VWStride_Y); - } - if(threadIdx.x <= 63 && threadIdx.y == 0) - { - printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", - threadIdx.x, - get<0>(baseOffset()), - get<1>(baseOffset())); - } - } - }; - - template // # of splits - struct ColOrthoInt - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, - - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), - - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, - - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, - - BlockKStride_X = 1u, - BlockKStride_Y = 0u, - - VWStride_X = 0u, - VWStride_Y = VectorWidth, - - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = DimPerThread / BlockKStride_X, - VWSegs = KPerThread / VWStride_Y, - }; - - // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); - - // Check VectorWidth validity - static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); - static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, - "KPerThread not a multiple of VectorWidth"); - - // Check SplitK validity - static_assert(BlockK >= SplitK, "Invalid SplitK"); - static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); - - // Check MfmaDim validity - static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); - static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowOrthoInt; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); - } - - // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetX = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetX = (int32_t)Traits::VWStride_X; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::SplitKSegs > 1)) - { - // "Reset" cycle - VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::SplitKSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::SplitKStride_X); - } - - return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); - } - - // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetX - = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::SplitKStride_X; - - return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); - } - - ROCWMMA_DEVICE static inline auto debug() - { - // if(threadIdx.x == 0 && threadIdx.y == 0) - // { - // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", - // (uint32_t)Traits::SplitKSegs, - // (uint32_t)Traits::BlockKSegs, - // (uint32_t)Traits::VWSegs); - - // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", - // (uint32_t)Traits::SplitKStride_X, - // (uint32_t)Traits::SplitKStride_Y, - // (uint32_t)Traits::BlockKStride_X, - // (uint32_t)Traits::BlockKStride_Y, - // (uint32_t)Traits::VWStride_X, - // (uint32_t)Traits::VWStride_Y); - - // } - // if(threadIdx.x <= 63 && threadIdx.y == 0) - // { - // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); - // } - } - }; - - template - struct RowInlineInt - { - // RowInlineInt is orthogonal to ColInlineInt, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout = ColInlineInt; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } - - ROCWMMA_DEVICE static inline auto debug() - { - Traits::OrthoLayout::debug(); - } - }; - - template - struct RowOrthoInt - { - // RowOrthoInt is orthogonal to ColOrthoInt, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout = ColOrthoInt; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } - - ROCWMMA_DEVICE static inline auto debug() {} - }; } // namespace MatrixLayout - //////////////////////////////////// - /// MatrixLayout specializations /// - //////////////////////////////////// - - // Matrix layout matching test criteria are if all parameters match, with some flexibility in VectorWidth. - template - class MatrixLayout> - struct is_layout_same< - MatrixLayout, - MatrixLayout> - : public integral_constant - { - }; - - // Matrix layout transpose test with flexibility in the VectorWidth. - // Transposed matrix layouts swap matrix space rows / cols. - template - struct is_layout_transpose, - MatrixLayout::template RowOrthoInt> - : public integral_constant - { - }; - - template - struct is_layout_transpose, - MatrixLayout::template ColOrthoInt> - : public integral_constant - { - }; - - template - struct is_layout_transpose, - MatrixLayout::template RowInlineInt> - : public integral_constant - { - }; - - template - struct is_layout_transpose, - MatrixLayout::template ColInlineInt> - : public integral_constant - { - }; - - // Matrix space transpose guide: Swap rows / cols - // VW stays consistent. - template - struct layout_transpose> - { - using type = MatrixLayout::template RowOrthoInt; - }; - - template - struct layout_transpose> - { - using type = MatrixLayout::template ColOrthoInt; - }; - - template - struct layout_transpose> - { - using type = MatrixLayout::template RowInlineInt; - }; - - template - struct layout_transpose> - { - using type = MatrixLayout::template ColInlineInt; - }; - - /////////////////////////////////////// - /// Register layout specializations /// - /////////////////////////////////////// - - // Register layouts are the same if all test parameters match, with some flexibility in VectorWidth. - template - class MatrixLayout> - struct is_layout_same< - RegisterLayout::template Storage< - MatrixLayout>, - RegisterLayout::template Storage< - MatrixLayout>> - : public integral_constant - { - }; - - // ColOrthoInt and RowOrthoInt layouts are already in mma input format for mma sized BlockDim (16 or 32) - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - // TODO: necessary? - // In-register layouts for transposed RowOrthoInt / ColOrthoInt and RowInlineInt / ColInline are technically 'the same' - // for each thread, even though the data interpretation is different (e.g., row elements vs col elements). - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_same< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - // ColInlineInt and RowInlineInt layouts are transposed to mma input format for mma sized BlockDim (16 or 32) - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template MmaInput> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template MmaInput, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - // In-register layouts for (ColInlineInt / ColOrthoInt) and (RowInlineInt / RowOrthoInt) are the orthogonal register transposes. - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - // In-register layouts for (RowOrthoInt / ColInlineInt) and (ColOrthoInt / RowInlineInt) are the orthogonal register transposes. - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - - template - struct is_layout_transpose< - RegisterLayout::template Storage>, - RegisterLayout::template Storage>> - : public integral_constant - { - }; - } // namespace rocwmma #endif // ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp b/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp new file mode 100644 index 00000000..d87999b3 --- /dev/null +++ b/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp @@ -0,0 +1,445 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP +#define ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP + +#include "config.hpp" +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + // Common helpers for supported traits + namespace LayoutTraits_impl + { + // Based on the current config, determine the compatibility of the mma dimension + constexpr static inline bool testSupportedMmaDim(uint32_t MmaDim) + { + return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && MmaDim == 16u) + || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && (MmaDim == 16u || MmaDim == 32u)); + } + + // VW can be changed from vw0 to vw1 as long as they have the same maxVW, and that maxVW + // is a multiple of both vw values + constexpr static inline bool testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) + { + return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); + } + + // Reference regular layouts + using MatrixLayout::ColInlineVW; + using MatrixLayout::ColOrthoVW; + using MatrixLayout::RowInlineVW; + using MatrixLayout::RowOrthoVW; + + // Reference interleaved layouts + using MatrixLayout::ColInlineInt; + using MatrixLayout::ColOrthoInt; + using MatrixLayout::RowInlineInt; + using MatrixLayout::RowOrthoInt; + + // NOTE: MatrixLayout assumptions + // When determining MatrixLayout traits, there are several strong assumptions. + // 1. Regarding same-ness: MatrixLayouts must match, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Same) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // + // 2. Regarding orthogonality: for all Col* layouts, their Row* + // orthogonal counterparts are implemented by row / col coordinate swaps. + // This is valid as long as we have some fixed parameters, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Orthogonal) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // This defines the need for MatrixLayout classifiers based upon: + // - ColOrtho / RowOrtho + // - ColInline / RowInline + // - Non-interleave / non-interleaved + // + // Following the above traits, we can build more complicated traits such as + // is_same, is_orthogonal and orthogonal_layout. + + // Classifier for ColOrtho MatrixLayout + template + struct is_col_ortho : public false_type + { + }; + + template + struct is_col_ortho> + : public true_type + { + }; + + template + struct is_col_ortho> + : public true_type + { + }; + + // Classifier for RowOrtho MatrixLayout + template + struct is_row_ortho : public false_type + { + }; + + template + struct is_row_ortho> + : public true_type + { + }; + + template + struct is_row_ortho> + : public true_type + { + }; + + // Classifier for ColInline MatrixLayout + template + struct is_col_inline : public false_type + { + }; + + template + struct is_col_inline> + : public true_type + { + }; + + template + struct is_col_Inline> + : public true_type + { + }; + + // Classifier for RowInline MatrixLayout + template + struct is_row_inline : public false_type + { + }; + + template + struct is_row_inline> + : public true_type + { + }; + + template + struct is_row_inline> + : public true_type + { + }; + + // Classifier for interleaved layout + template + struct is_interleaved : public false_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + // Convenience evaluators + template + constexpr static bool is_col_ortho_v = is_col_ortho::value; + + template + constexpr static bool is_row_ortho_v = is_row_ortho::value; + + template + constexpr static bool is_col_inline_v = is_col_inline::value; + + template + constexpr static bool is_row_inline_v = is_row_inline::value; + + template + constexpr static bool is_interleaved_v = is_interleaved::value; + + // When comparing one MatrixLayout to another, we need a way to check parameter compatibility. + template + struct is_compatible_params : public false_type + { + }; + + // Non-interleaved matrix layouts require that BlockDim, BlockK, MaxVW are matching. + // VectorWidth values must satisfy criterion in testSupportedVW(). + template + class MatrixLayoutLhs, + template + class MatrixLayoutRhs> + struct is_compatible_params< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t< + !is_interleaved_v< + MatrixLayoutLhs> + && !is_interleaved_v< + MatrixLayoutRhs> + && testSupportedVW(MaxVectorWidth, VectorWidthLhs, VectorWidthRhs)>> + : public true_type + { + }; + + // Interleaved matrix layouts require that BlockDim, BlockK, MmaDim, SplitK are matching. + // MmaDim values must satisfy criterion in testSupportedMmaDim(). + template + class MatrixLayoutLhs, + template + class MatrixLayoutRhs> + struct is_compatible_params< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t< + is_interleaved_v> + && is_interleaved_v> + && testSupportedMmaDim(MmaDim)>> : public true_type + { + }; + + // Convenience evaluator + template + constexpr static bool is_compatible_params_v + = is_compatible_params::value; + + // Classifier to test same-ness, implements criterion #1 from above: + template + struct is_layout_same< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t<((is_col_ortho_v && is_col_ortho_v) + || (is_row_ortho_v && is_row_ortho_v) + || (is_col_inline_v && is_col_inline_v) + || (is_row_inline_v && is_row_inline_v)) + && is_compatible_params_v>> + : public true_type + { + }; + + // Classifier to test orthogonality, implements criterion #2 from above: + template + struct is_layout_orthogonal< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t<((is_col_ortho_v && is_row_ortho_v) + || (is_row_ortho_v && is_col_ortho_v) + || (is_col_inline_v && is_row_inline_v) + || (is_row_inline_v && is_col_inline_v)) + && is_compatible_params_v>> + : public true_type + { + }; + + // Matrix space transpose guide: Swap rows / cols + // VW stays consistent. + template + struct orthogonal_layout> + { + using type = RowOrthoVW; + }; + + template + struct orthogonal_layout> + { + using type = ColOrthoVW; + }; + + template + struct orthogonal_layout> + { + using type = RowInlineVW; + }; + + template + struct orthogonal_layout> + { + using type = ColInlineVW; + }; + + // Orthogonal guide for interleaved layouts + template + struct orthogonal_layout> + { + using type = RowOrthoInt; + }; + + template + struct orthogonal_layout> + { + using type = ColOrthoInt; + }; + + template + struct orthogonal_layout> + { + using type = RowInlineInt; + }; + + template + struct orthogonal_layout> + { + using type = ColInlineInt; + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#endif // ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_impl.hpp deleted file mode 100644 index 9118f08f..00000000 --- a/library/include/rocwmma/internal/layout/register_layout_impl.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_REGISTER_LAYOUT_IMPL_HPP -#define ROCWMMA_REGISTER_LAYOUT_IMPL_HPP - -#include "layout.hpp" -#include "layout_traits.hpp" -#include "utility/type_traits.hpp" - -namespace rocwmma -{ - // Use generic MatrixLayout transpose rules to guide the register layout transpose suggestion - template - struct layout_transpose> - { - using type = RegisterLayout::template Storage>; - }; - -} // namespace rocwmma - -#endif // ROCWMMA_REGISTER_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_traits.hpp b/library/include/rocwmma/internal/layout/register_layout_traits.hpp new file mode 100644 index 00000000..6c0593ba --- /dev/null +++ b/library/include/rocwmma/internal/layout/register_layout_traits.hpp @@ -0,0 +1,352 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP +#define ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" +#include "utility/type_traits.hpp" + +namespace rocwmma +{ + namespace LayoutTraits_impl + { + + // NOTE: RegisterLayout assumptions + // When determining RegisterLayout traits, there are several strong assumptions. + // 1. Regarding same-ness: + // - Storage match if MatrixLayouts match, given fixed params. + // - Storage match if MatrixLayouts are either both *Ortho or both *Inline + // orientations. Register thread mapping is the same while swapping the underlying + // meaning of rows for cols (e.g., implicit transpose). + // - Storage<*Ortho> layouts are suitable MmaInputs while Storage<*Inline> layouts are not. + // Given appropriate MmaDim, it is assumed MmaInput layouts are mapped to mma hardware + // requirements. + // _________________________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Same) | Required Fixed Params | + // | ------------------------------------------------------------------------------- | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | MmaInput | BlockDim == MmaDim | + // | MmaInput | Storage | BlockDim == MmaDim | + // | Storage | MmaInput | BlockDim == MmaDim | + // | MmaInput | Storage | BlockDim == MmaDim | + // | ------------------------------------------------------------------------------- | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | MmaInput | MmaDim | + // | MmaInput | Storage | MmaDim | + // | Storage | MmaInput | MmaDim | + // | MmaInput | Storage | MmaDim | + // | ------------------------------------------------------------------------------- | + // + // 2. Regarding orthogonality: + // - Storages are considered orthogonal if one MatrixLayout is an + // *Ortho layout and the other is an *Inline layout, or vice versa. + // - Since MmaInput layouts are same as Storage layouts with appropriate + // MmaDim, MmaInput is also orthogonal to Storage layouts. + // _______________________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Required Fixed Params | + // | | (Transposed) | | + // | ----------------------------------------------------------------------------- | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | + // | Storage | MmaInput | BlockDim == MmaDim | + // | MmaInput | Storage | BlockDim == MmaDim | + // | Storage | MmaInput | BlockDim == MmaDim | + // | MmaInput | Storage | BlockDim == MmaDim | + // | ----------------------------------------------------------------------------- | + // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | + // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | + // | Storage | MmaInput | MmaDim | + // | MmaInput | Storage| MmaDim | + // | Storage | MmaInput | MmaDim | + // | MmaInput | Storage| MmaDim | + // | ----------------------------------------------------------------------------- | + + using RegisterLayout::MmaAcc; + using RegisterLayout::MmaInput; + using RegisterLayout::Storage; + + // Classifier for storage of col ortho + template + struct is_storage_col_ortho : public false_type + { + }; + + template + struct is_storage_col_ortho enable_if_t>> + : public true_type + { + }; + + // Classifier for storage of row ortho + template + struct is_storage_row_ortho : public false_type + { + }; + + template + struct is_storage_row_ortho, + enable_if_t>> : public true_type + { + }; + + // Classifier for storage of col inline + template + struct is_storage_col_inline : public false_type + { + }; + + template + struct is_storage_col_inline, + enable_if_t>> : public true_type + { + }; + + // Classifier for storage of row inline + template + struct is_storage_row_inline : public false_type + { + }; + + template + struct is_storage_row_inline, + enable_if_t>> : public true_type + { + }; + + // Classifier for mma inputs + template + struct is_mma_input : public false_type + { + }; + + template + struct is_mma_input> : public true_type + { + }; + + // Convenience evaluators + template + constexpr static bool is_storage_col_ortho_v = is_storage_col_ortho::value; + + template + constexpr static bool is_storage_row_ortho_v = is_storage_row_ortho::value; + + template + constexpr static bool is_storage_col_inline_v + = is_storage_col_inline::value; + + template + constexpr static bool is_storage_row_inline_v + = is_storage_row_inline::value; + + template + constexpr static bool is_mma_input_v = is_mma_input::value; + + // Compatibility for Storage, passthrough to MatrixLayout compatibility. + template + struct is_compatible_params, Storage, void> + : public is_compatible_params + { + }; + + // Non-interleaved MmaInput layouts require a valid MmaDim (as BlockDim). + // MmaDim values must hold to certain criterion in testSupportedMmaDim(). + template + class MatrixLayout> + struct is_compatible_params< + Storage>, + MmaInput, + enable_if_t> + && testSupportedMmaDim(BlockDim)>> : public true_type + { + }; + + template + class MatrixLayout> + struct is_compatible_params< + MmaInput, + Storage>, + enable_if_t> + && testSupportedMmaDim(BlockDim)>> : public true_type + { + }; + + // Interleaved MmaInput layouts require a valid MmaDim. + // MmaDim values must hold to certain criterion in testSupportedMmaDim(). + template + class MatrixLayout> + struct is_compatible_params< + Storage>, + MmaInput, + enable_if_t> + && testSupportedMmaDim(MmaSize)>> : public true_type + { + }; + + template + class MatrixLayout> + struct is_compatible_params< + MmaInput, + Storage>, + enable_if_t> + && testSupportedMmaDim(MmaSize)>> : public true_type + { + }; + + // Checks if both RegisterLayout storages are the same with compatible params + template + struct is_layout_same< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t< + // Check for same in-register layouts + ((is_storage_col_ortho_v + && is_storage_col_ortho_v) + || (is_storage_row_ortho_v + && is_storage_row_ortho_v) + || (is_storage_col_inline_v + && is_storage_col_inline_v) + || (is_storage_row_inline_v + && is_storage_row_inline_v) + // Check for in-register implicit transposes. These have the same register layouts, but swap meaning + // for rows / cols. + || (is_storage_col_ortho_v + && is_storage_row_ortho_v) + || (is_storage_row_ortho_v + && is_storage_col_ortho_v) + || (is_storage_col_inline_v + && is_storage_row_inline_v) + || (is_storage_row_inline_v + && is_storage_col_inline_v) + // Check for mma input compatibility + || (is_storage_col_ortho_v && is_mma_input_v) + || (is_mma_input_v && is_storage_col_ortho_v) + || (is_storage_row_ortho_v && is_mma_input_v) + || (is_mma_input_v + && is_storage_row_ortho_v)) + && is_compatible_params_v>> : public true_type + { + }; + + // Checks if RegisterLayouts are transposed with compatible params + template + struct is_layout_orthogonal< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<( // Orthogonality in same orientation (e.g., col / row) + (is_storage_col_ortho_v + && is_storage_col_inline_v) + || (is_storage_col_inline_v + && is_storage_col_ortho_v) + || (is_storage_row_ortho_v + && is_storage_row_inline_v) + || (is_storage_row_inline_v + && is_storage_row_ortho_v) + // Orthogonality in opposite orientation (e.g., col vs row) + || (is_storage_col_ortho_v + && is_storage_row_inline_v) + || (is_storage_row_inline_v + && is_storage_col_ortho_v) + || (is_storage_col_inline_v + && is_storage_row_ortho_v) + || (is_storage_row_ortho_v + && is_storage_col_inline_v) + // Mma orthogonality + || (is_storage_col_inline_v + && is_mma_input_v) + || (is_mma_input_v + && is_storage_col_inline_v) + || (is_storage_row_inline_v + && is_mma_input_v) + || (is_mma_input_v + && is_storage_row_inline_v)) + && is_compatible_params_v>> + : public true_type + { + }; + + // Use generic MatrixLayout orthogonality rules to guide the register layout transpose suggestion + template + struct orthogonal_layout> + { + using type = Storage::type>; + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#endif // ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP From e15f5b0d3df779e90172af064d2317beab83f637 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 12 Sep 2024 20:23:38 +0000 Subject: [PATCH 05/36] Remove unused file --- .../layout/matrix_layout_interleaved_impl.hpp | 43 ------------------- 1 file changed, 43 deletions(-) delete mode 100644 library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp diff --git a/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp deleted file mode 100644 index e866017b..00000000 --- a/library/include/rocwmma/internal/layout/matrix_layout_interleaved_impl.hpp +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP -#define ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP - -#include "layout.hpp" -#include "layout_traits.hpp" - -namespace rocwmma -{ - - // Implementations for the interleaved MatrixLayout classes - namespace MatrixLayout - { - - } // namespace MatrixLayout - -} // namespace rocwmma - -#endif // ROCWMMA_MATRIX_LAYOUT_INTERLEAVED_IMPL_HPP From c27e4e1739d814222e0bb690e2ad027764b245fc Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 26 Sep 2024 16:49:47 +0000 Subject: [PATCH 06/36] Refactor layout traits --- .../include/rocwmma/internal/coop_load.hpp | 1 - .../include/rocwmma/internal/coop_store.hpp | 1 - .../include/rocwmma/internal/io_layout.hpp | 2 +- .../rocwmma/internal/layout/layout.hpp | 4 +- .../internal/layout/layout_profile.hpp | 64 +-- .../rocwmma/internal/layout/layout_traits.hpp | 5 +- .../internal/layout/layout_traits_impl.hpp | 2 +- .../internal/layout/matrix_layout_impl.hpp | 61 ++- .../internal/layout/matrix_layout_traits.hpp | 307 ++++++++----- .../layout/register_layout_traits.hpp | 419 ++++++++++-------- .../layout/register_layout_transforms.hpp | 130 ++++++ .../include/rocwmma/internal/opaque_load.hpp | 1 - .../include/rocwmma/internal/opaque_store.hpp | 6 +- .../include/rocwmma/internal/transforms.hpp | 179 -------- .../rocwmma/internal/utility/type_traits.hpp | 32 +- library/include/rocwmma/rocwmma_impl.hpp | 10 +- .../rocwmma/rocwmma_transforms_impl.hpp | 28 +- 17 files changed, 683 insertions(+), 569 deletions(-) create mode 100644 library/include/rocwmma/internal/layout/register_layout_transforms.hpp diff --git a/library/include/rocwmma/internal/coop_load.hpp b/library/include/rocwmma/internal/coop_load.hpp index 11bb445d..390ba3f7 100644 --- a/library/include/rocwmma/internal/coop_load.hpp +++ b/library/include/rocwmma/internal/coop_load.hpp @@ -27,7 +27,6 @@ #define ROCWMMA_COOP_LOAD_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "opaque_load.hpp" #include "types.hpp" #include "utils.hpp" diff --git a/library/include/rocwmma/internal/coop_store.hpp b/library/include/rocwmma/internal/coop_store.hpp index 9f0f22b7..5086e057 100644 --- a/library/include/rocwmma/internal/coop_store.hpp +++ b/library/include/rocwmma/internal/coop_store.hpp @@ -27,7 +27,6 @@ #define ROCWMMA_COOP_STORE_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "opaque_store.hpp" #include "types.hpp" #include "utils.hpp" diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 1e6a0555..ebcea724 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -28,7 +28,7 @@ #include "api_fwd.hpp" #include "constants.hpp" -#include "layout.hpp" +#include "layout/layout_profile.hpp" #include "types.hpp" namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index a50d89cf..a860b61c 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -26,8 +26,8 @@ #ifndef ROCWMMA_LAYOUT_HPP #define ROCWMMA_LAYOUT_HPP -#include "api_fwd.hpp" -#include "mapping_util.hpp" +#include "../api_fwd.hpp" +#include "../mapping_util.hpp" namespace rocwmma { diff --git a/library/include/rocwmma/internal/layout/layout_profile.hpp b/library/include/rocwmma/internal/layout/layout_profile.hpp index 4e243061..cebab878 100644 --- a/library/include/rocwmma/internal/layout/layout_profile.hpp +++ b/library/include/rocwmma/internal/layout/layout_profile.hpp @@ -218,29 +218,16 @@ namespace rocwmma uint32_t BlockK, typename DataT, typename DataLayoutT, - uint32_t VectorWidth, - uint32_t MaxVectorWidth = VectorWidth, - uint32_t MfmaDim = 16u, - uint32_t SplitK = 1u> + uint32_t MfmaDim = 16u, + uint32_t SplitK = 1u> struct ColInt { // Layouts using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t, - MatrixLayout::ColInlineInt, - MatrixLayout::ColOrthoInt>; + using MatrixLayout = conditional_t< + is_same_v, + MatrixLayout::ColInlineInt, + MatrixLayout::ColOrthoInt>; using RegisterLayout = RegisterLayout::Storage; @@ -250,9 +237,10 @@ namespace rocwmma // Sanity checks // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); + // TODO: fix + // static_assert( + // !(is_same_v && (MaxVectorWidth > BlockK)), + // "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); }; // Row is a layout profile that has the following properties: @@ -270,29 +258,16 @@ namespace rocwmma uint32_t BlockK, typename DataT, typename DataLayoutT, - uint32_t VectorWidth, - uint32_t MaxVectorWidth = VectorWidth, - uint32_t MfmaDim = 16u, - uint32_t SplitK = 1u> + uint32_t MfmaDim = 16u, + uint32_t SplitK = 1u> struct RowInt { // Layouts using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t, - MatrixLayout::RowInlineInt, - MatrixLayout::RowOrthoInt>; + using MatrixLayout = conditional_t< + is_same_v, + MatrixLayout::RowInlineInt, + MatrixLayout::RowOrthoInt>; using RegisterLayout = RegisterLayout::Storage; @@ -302,9 +277,10 @@ namespace rocwmma // Sanity checks // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); + // TODO: fix + // static_assert( + // !(is_same_v && (MaxVectorWidth > BlockK)), + // "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); }; } // namespace LayoutProfile diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp index 60002937..554952ea 100644 --- a/library/include/rocwmma/internal/layout/layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -26,10 +26,13 @@ #ifndef ROCWMMA_LAYOUT_TRAITS_HPP #define ROCWMMA_LAYOUT_TRAITS_HPP -#include "data_layout_traits.hpp" +// Need strict inclusion order here +// clang-format off #include "layout_traits_impl.hpp" +#include "data_layout_traits.hpp" #include "matrix_layout_traits.hpp" #include "register_layout_traits.hpp" +// clang-format on namespace rocwmma { diff --git a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp index b7647aa1..5e84fd46 100644 --- a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp @@ -26,7 +26,7 @@ #ifndef ROCWMMA_LAYOUT_TRAITS_IMPL_HPP #define ROCWMMA_LAYOUT_TRAITS_IMPL_HPP -#include "utility/type_traits.hpp" +#include "../utility/type_traits.hpp" namespace rocwmma { diff --git a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp index 8710c8ba..d2d1c84e 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -930,45 +930,76 @@ namespace rocwmma template struct OrthoImpl { - struct Traits - { - using OrthoLayout = orthogonal_layout_t; - }; - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename auto baseOffset() + ROCWMMA_DEVICE static inline auto baseOffset() { - return swap(Traits::OrthoLayout::baseOffset()); + return swap(MatrixLayout::baseOffset()); } ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - return Traits::OrthoLayout::strideCounts(); + return MatrixLayout::strideCounts(); } ROCWMMA_DEVICE constexpr static inline auto strides() { - auto t = Traits::OrthoLayout::strides(); + auto t = MatrixLayout::strides(); + // TODO: use apply + //apply([](auto const& v){ return swap(v); }); return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); + return swap(MatrixLayout::incrementalOffset(iteration)); } ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); + return swap(MatrixLayout::cumulativeOffset(iteration)); } ROCWMMA_DEVICE static inline auto debug() {} }; - using RowOrthoVW = OrthoImpl; - using RowInlineVW = OrthoImpl; - using RowOrthoInt = OrthoImpl; - using RowInlineInt = OrthoImpl; + template + struct RowOrthoVW + : public OrthoImpl> + { + }; + + template + struct RowInlineVW + : public OrthoImpl> + { + }; + + template // # of splits + struct RowOrthoInt : public OrthoImpl> + { + }; + + template // # of splits + struct RowInlineInt + : public OrthoImpl> + { + }; } // namespace MatrixLayout diff --git a/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp b/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp index d87999b3..ad1e42f1 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp @@ -26,7 +26,7 @@ #ifndef ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP #define ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP -#include "config.hpp" +#include "../config.hpp" #include "layout.hpp" #include "layout_traits.hpp" @@ -61,50 +61,14 @@ namespace rocwmma using MatrixLayout::RowInlineInt; using MatrixLayout::RowOrthoInt; - // NOTE: MatrixLayout assumptions - // When determining MatrixLayout traits, there are several strong assumptions. - // 1. Regarding same-ness: MatrixLayouts must match, as defined below: - // ____________________________________________________________________ - // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | - // | | (Same) | Required Fixed Params | - // | ------------------------------------------------------------------ | - // | ColOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | - // | ColInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | - // | RowOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | - // | RowInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | - // | ------------------------------------------------------------------ | - // | ColOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | - // | ColInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | - // | RowOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | - // | RowInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | - // -------------------------------------------------------------------- - // - // 2. Regarding orthogonality: for all Col* layouts, their Row* - // orthogonal counterparts are implemented by row / col coordinate swaps. - // This is valid as long as we have some fixed parameters, as defined below: - // ____________________________________________________________________ - // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | - // | | (Orthogonal) | Required Fixed Params | - // | ------------------------------------------------------------------ | - // | ColOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | - // | ColInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | - // | RowOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | - // | RowInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | - // | ------------------------------------------------------------------ | - // | ColOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | - // | ColInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | - // | RowOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | - // | RowInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | - // -------------------------------------------------------------------- - // This defines the need for MatrixLayout classifiers based upon: - // - ColOrtho / RowOrtho - // - ColInline / RowInline - // - Non-interleave / non-interleaved - // - // Following the above traits, we can build more complicated traits such as - // is_same, is_orthogonal and orthogonal_layout. - - // Classifier for ColOrtho MatrixLayout + // Start to build a basic set of meta-data classifiers. + // We will be interested in knowing things about our matrix layouts: + // - is_col_ortho + // - is_row_ortho + // - is_col_inline + // - is_row_inline + // - is_interleaved + // - is_matrix_layout template struct is_col_ortho : public false_type { @@ -130,7 +94,6 @@ namespace rocwmma { }; - // Classifier for RowOrtho MatrixLayout template struct is_row_ortho : public false_type { @@ -156,7 +119,6 @@ namespace rocwmma { }; - // Classifier for ColInline MatrixLayout template struct is_col_inline : public false_type { @@ -177,12 +139,11 @@ namespace rocwmma typename DataT, uint32_t MfmaDim, uint32_t SplitK> - struct is_col_Inline> + struct is_col_inline> : public true_type { }; - // Classifier for RowInline MatrixLayout template struct is_row_inline : public false_type { @@ -208,7 +169,6 @@ namespace rocwmma { }; - // Classifier for interleaved layout template struct is_interleaved : public false_type { @@ -270,74 +230,206 @@ namespace rocwmma template constexpr static bool is_interleaved_v = is_interleaved::value; - // When comparing one MatrixLayout to another, we need a way to check parameter compatibility. - template - struct is_compatible_params : public false_type + template + struct is_matrix_layout + : public integral_constant || is_col_inline_v + || is_row_ortho_v + || is_row_inline_v> { }; - // Non-interleaved matrix layouts require that BlockDim, BlockK, MaxVW are matching. - // VectorWidth values must satisfy criterion in testSupportedVW(). - template + constexpr static bool is_matrix_layout_v = is_matrix_layout::value; + + // Next we can build a set of base trait accessors for the MatrixLayout. These + // will be reflective of the input template params of the MatrixLayout instance. + + template + struct matrix_layout_base_traits; + + // Represent non-interleaved MatrixLayout instances + template - class MatrixLayoutLhs, + class MatrixLayout> + struct matrix_layout_base_traits< + MatrixLayout, + enable_if_t> + && !is_interleaved_v>>> + { + constexpr static uint32_t BlockDim = LayoutBlockDim; + constexpr static uint32_t KDim = LayoutBlockK; + using DataT = LayoutDataT; + constexpr static uint32_t VectorWidth = LayoutVectorWidth; + constexpr static uint32_t MaxVectorWidth = LayoutMaxVectorWidth; + }; + + // Represent interleaved MatrixLayout instances + template - class MatrixLayoutRhs> - struct is_compatible_params< - MatrixLayoutLhs, - MatrixLayoutRhs, - enable_if_t< - !is_interleaved_v< - MatrixLayoutLhs> - && !is_interleaved_v< - MatrixLayoutRhs> - && testSupportedVW(MaxVectorWidth, VectorWidthLhs, VectorWidthRhs)>> - : public true_type + class MatrixLayout> + struct matrix_layout_base_traits< + MatrixLayout, + enable_if_t> + && is_interleaved_v>>> + { + constexpr static uint32_t BlockDim = LayoutBlockDim; + constexpr static uint32_t KDim = LayoutBlockK; + using DataT = LayoutDataT; + constexpr static uint32_t MmaDim = LayoutMmaDim; + constexpr static uint32_t SplitK = LayoutSplitK; + }; + + // Combine base instance traits with specific layout classifiers + template + struct matrix_layout_traits : public matrix_layout_base_traits + { + constexpr static bool is_col_ortho = is_col_ortho_v; + constexpr static bool is_col_inline = is_col_inline_v; + constexpr static bool is_row_ortho = is_row_ortho_v; + constexpr static bool is_row_inline = is_row_inline_v; + constexpr static bool is_interleaved = is_interleaved_v; + constexpr static bool is_matrix_layout = is_matrix_layout_v; + }; + + // NOTE: MatrixLayout assumptions + // When determining MatrixLayout traits, there are several strong assumptions. + // 1. Regarding same-ness: MatrixLayouts must match, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Same) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // + // 2. Regarding orthogonality: for all Col* layouts, their Row* + // orthogonal counterparts are implemented by row / col coordinate swaps. + // This is valid as long as we have some fixed parameters, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Orthogonal) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // This defines the need for MatrixLayout classifiers based upon: + // - ColOrtho / RowOrtho + // - ColInline / RowInline + // - Non-interleave / non-interleaved + // + // Following the above traits, we can build more complicated traits such as + // is_same, is_orthogonal and orthogonal_layout. + + // When comparing one MatrixLayout to another, we need a way to check parameter compatibility. + template + struct is_compatible_matrix_params : public false_type { }; - // Interleaved matrix layouts require that BlockDim, BlockK, MmaDim, SplitK are matching. - // MmaDim values must satisfy criterion in testSupportedMmaDim(). - template - class MatrixLayoutLhs, - template - class MatrixLayoutRhs> - struct is_compatible_params< - MatrixLayoutLhs, - MatrixLayoutRhs, - enable_if_t< - is_interleaved_v> - && is_interleaved_v> - && testSupportedMmaDim(MmaDim)>> : public true_type +// Keeps things a bit more tidy. Quick access to matrix layout traits. +#define mat_traits_lhs matrix_layout_traits +#define mat_traits_rhs matrix_layout_traits + + // Non-interleaved matrix layout compatibility requires: + // 1. Must have fixed: BlockDim, KDim, MaxVectorWidth + // 2. VectorWidths must satisfy criterion in testSupportedVW(). + template + struct is_compatible_matrix_params< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t<(!mat_traits_lhs::is_interleaved && !mat_traits_rhs::is_interleaved)>> + : public integral_constant + { + }; + + // Interleaved matrix layout compatibility requires: + // 1. Must have fixed BlockDim, BlockK, MmaDim, SplitK + // 2. MmaDim values must satisfy criterion in testSupportedMmaDim(). + template + struct is_compatible_matrix_params< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t<(mat_traits_lhs::is_interleaved && mat_traits_rhs::is_interleaved)>> + : public integral_constant { }; // Convenience evaluator template - constexpr static bool is_compatible_params_v - = is_compatible_params::value; + constexpr static bool is_compatible_matrix_params_v + = is_compatible_matrix_params::value; + + // Now to implement the interfaces for is_layout_same and is_layout_orthogonal, + // with MatrixLayout types. // Classifier to test same-ness, implements criterion #1 from above: template struct is_layout_same< MatrixLayoutLhs, MatrixLayoutRhs, - enable_if_t<((is_col_ortho_v && is_col_ortho_v) - || (is_row_ortho_v && is_row_ortho_v) - || (is_col_inline_v && is_col_inline_v) - || (is_row_inline_v && is_row_inline_v)) - && is_compatible_params_v>> - : public true_type + enable_if_t> + : public integral_constant< + bool, + ((mat_traits_lhs::is_col_ortho_v && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_row_ortho_v && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_col_inline_v && mat_traits_rhs::is_col_inline) + || (mat_traits_lhs::is_row_inline_v && mat_traits_rhs::is_row_inline)) + && is_compatible_matrix_params_v> { }; @@ -346,15 +438,20 @@ namespace rocwmma struct is_layout_orthogonal< MatrixLayoutLhs, MatrixLayoutRhs, - enable_if_t<((is_col_ortho_v && is_row_ortho_v) - || (is_row_ortho_v && is_col_ortho_v) - || (is_col_inline_v && is_row_inline_v) - || (is_row_inline_v && is_col_inline_v)) - && is_compatible_params_v>> - : public true_type + enable_if_t> + : public integral_constant< + bool, + ((mat_traits_lhs::is_col_ortho_v && mat_traits_rhs::is_row_ortho_v) + || (mat_traits_lhs::is_row_ortho_v && mat_traits_rhs::is_col_ortho_v) + || (mat_traits_lhs::is_col_inline_v && mat_traits_rhs::is_row_inline_v) + || (mat_traits_lhs::is_row_inline_v && mat_traits_rhs::is_col_inline_v)) + && is_compatible_matrix_params_v> { }; +#undef mat_traits_lhs +#undef mat_traits_rhs + // Matrix space transpose guide: Swap rows / cols // VW stays consistent. template + struct is_register_layout : public false_type + { + }; + + template + struct is_register_layout> : public is_matrix_layout + { + }; + + template + struct is_register_layout> : public true_type + { + }; + + template + struct is_register_layout> : public true_type + { + }; + + template + struct is_storage_layout : public false_type + { + }; + + template + struct is_storage_layout> : public is_matrix_layout + { + }; + + template + struct is_mma_input_layout : public false_type + { + }; + + template + struct is_mma_input_layout> : public true_type + { + }; + + template + struct is_mma_acc_layout : public false_type + { + }; + + template + struct is_mma_acc_layout> : public true_type + { + }; + + // Convenience evaluators + template + constexpr inline static bool is_register_layout_v + = is_register_layout::value; + + template + constexpr inline static bool is_storage_layout_v = is_storage_layout::value; + + template + constexpr inline static bool is_mma_input_layout_v + = is_mma_input_layout::value; + + template + constexpr inline static bool is_mma_acc_layout_v = is_mma_acc_layout::value; + + // Next we can build a set of base trait accessors for the RegisterLayout. These + // will be reflective of the input template params of the RegisterLayout instance. + template + struct register_layout_base_traits; + + template + struct register_layout_base_traits> + { + using MatrixLayout = MatrixLayoutInternal; + }; + + template + struct register_layout_base_traits> + { + constexpr static uint32_t MmaDim = LayoutMmaDim; + using MatrixLayout = void; + }; + + template + struct register_layout_base_traits> + { + constexpr static uint32_t MmaDim = LayoutMmaDim; + using MatrixLayout = void; + }; + + // Combine base instance traits with specific layout classifiers + template + struct register_layout_traits : public register_layout_base_traits + { + constexpr static bool is_register_layout = is_register_layout_v; + constexpr static bool is_storage_layout = is_storage_layout_v; + constexpr static bool is_mma_input_layout = is_mma_input_layout_v; + constexpr static bool is_mma_acc_layout = is_mma_acc_layout_v; + }; // NOTE: RegisterLayout assumptions // When determining RegisterLayout traits, there are several strong assumptions. @@ -112,192 +221,127 @@ namespace rocwmma // | MmaInput | Storage| MmaDim | // | ----------------------------------------------------------------------------- | - using RegisterLayout::MmaAcc; - using RegisterLayout::MmaInput; - using RegisterLayout::Storage; +// Keeps things a bit more tidy. Quick access to register layout traits. +#define reg_traits_lhs register_layout_traits +#define reg_traits_rhs register_layout_traits - // Classifier for storage of col ortho - template - struct is_storage_col_ortho : public false_type - { - }; +// Quick access to matrix layout traits, that are embedded in the register layout traits. +#define mat_traits_lhs matrix_layout_traits +#define mat_traits_rhs matrix_layout_traits - template - struct is_storage_col_ortho enable_if_t>> - : public true_type - { - }; - - // Classifier for storage of row ortho - template - struct is_storage_row_ortho : public false_type - { - }; + template + struct is_compatible_register_params; - template - struct is_storage_row_ortho, - enable_if_t>> : public true_type - { - }; - - // Classifier for storage of col inline - template - struct is_storage_col_inline : public false_type - { - }; - - template - struct is_storage_col_inline, - enable_if_t>> : public true_type - { - }; - - // Classifier for storage of row inline - template - struct is_storage_row_inline : public false_type - { - }; - - template - struct is_storage_row_inline, - enable_if_t>> : public true_type - { - }; - - // Classifier for mma inputs - template - struct is_mma_input : public false_type + // Compatibility for Storage is a passthrough to MatrixLayout compatibility. + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public is_compatible_matrix_params { }; - template - struct is_mma_input> : public true_type + // Compatibility for MmaInputs + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public integral_constant { }; - // Convenience evaluators - template - constexpr static bool is_storage_col_ortho_v = is_storage_col_ortho::value; - - template - constexpr static bool is_storage_row_ortho_v = is_storage_row_ortho::value; - - template - constexpr static bool is_storage_col_inline_v - = is_storage_col_inline::value; - - template - constexpr static bool is_storage_row_inline_v - = is_storage_row_inline::value; - - template - constexpr static bool is_mma_input_v = is_mma_input::value; - - // Compatibility for Storage, passthrough to MatrixLayout compatibility. - template - struct is_compatible_params, Storage, void> - : public is_compatible_params + // Non-interleaved register layout compatibility with MmaInput requires: + // 1. Inner matrix layout and mma input layout must have same: BlockDim / MmaDim + // 2. MmaDim must satisfy criterion in testSupportedMmaDim(). + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<(reg_traits_lhs::is_storage_layout && !mat_traits_lhs::is_interleaved) + && reg_traits_rhs::is_mma_input_layout>> + : public integral_constant { }; - // Non-interleaved MmaInput layouts require a valid MmaDim (as BlockDim). - // MmaDim values must hold to certain criterion in testSupportedMmaDim(). - template - class MatrixLayout> - struct is_compatible_params< - Storage>, - MmaInput, - enable_if_t> - && testSupportedMmaDim(BlockDim)>> : public true_type + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public integral_constant { }; - template - class MatrixLayout> - struct is_compatible_params< - MmaInput, - Storage>, - enable_if_t> - && testSupportedMmaDim(BlockDim)>> : public true_type + // Interleaved register layout compatibility with MmaInput requires: + // 1. Inner matrix layout and mma input layout must have same: MmaDim + // 2. MmaDim must satisfy criterion in testSupportedMmaDim(). + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<(reg_traits_lhs::is_storage_layout && mat_traits_lhs::is_interleaved) + && reg_traits_rhs::is_mma_input_layout>> + : public integral_constant { }; - // Interleaved MmaInput layouts require a valid MmaDim. - // MmaDim values must hold to certain criterion in testSupportedMmaDim(). - template - class MatrixLayout> - struct is_compatible_params< - Storage>, - MmaInput, - enable_if_t> - && testSupportedMmaDim(MmaSize)>> : public true_type + template + struct is_compatible_register_params< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public integral_constant { }; - template - class MatrixLayout> - struct is_compatible_params< - MmaInput, - Storage>, - enable_if_t> - && testSupportedMmaDim(MmaSize)>> : public true_type - { - }; + // Convenience evaluator + template + constexpr static inline bool is_compatible_register_params_v + = is_compatible_register_params::value; // Checks if both RegisterLayout storages are the same with compatible params template struct is_layout_same< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t< - // Check for same in-register layouts - ((is_storage_col_ortho_v - && is_storage_col_ortho_v) - || (is_storage_row_ortho_v - && is_storage_row_ortho_v) - || (is_storage_col_inline_v - && is_storage_col_inline_v) - || (is_storage_row_inline_v - && is_storage_row_inline_v) - // Check for in-register implicit transposes. These have the same register layouts, but swap meaning - // for rows / cols. - || (is_storage_col_ortho_v - && is_storage_row_ortho_v) - || (is_storage_row_ortho_v - && is_storage_col_ortho_v) - || (is_storage_col_inline_v - && is_storage_row_inline_v) - || (is_storage_row_inline_v - && is_storage_col_inline_v) - // Check for mma input compatibility - || (is_storage_col_ortho_v && is_mma_input_v) - || (is_mma_input_v && is_storage_col_ortho_v) - || (is_storage_row_ortho_v && is_mma_input_v) - || (is_mma_input_v - && is_storage_row_ortho_v)) - && is_compatible_params_v>> : public true_type + enable_if_t> + : public integral_constant< + bool, + // Check for same in-register layouts + ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_inline) + + // Check for in-register implicit transposes. These have the same register layouts, + // but swap meaning for rows / cols. + || (mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_inline) + + // Check mma input sameness + || (reg_traits_lhs::is_mma_input && reg_traits_rhs::is_mma_input) + || (mat_traits_lhs::is_col_ortho && reg_traits_rhs::is_mma_input) + || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_row_ortho && reg_traits_rhs::is_mma_input) + || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_row_ortho)) + && is_compatible_register_params_v> { }; @@ -306,38 +350,35 @@ namespace rocwmma struct is_layout_orthogonal< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t<( // Orthogonality in same orientation (e.g., col / row) - (is_storage_col_ortho_v - && is_storage_col_inline_v) - || (is_storage_col_inline_v - && is_storage_col_ortho_v) - || (is_storage_row_ortho_v - && is_storage_row_inline_v) - || (is_storage_row_inline_v - && is_storage_row_ortho_v) - // Orthogonality in opposite orientation (e.g., col vs row) - || (is_storage_col_ortho_v - && is_storage_row_inline_v) - || (is_storage_row_inline_v - && is_storage_col_ortho_v) - || (is_storage_col_inline_v - && is_storage_row_ortho_v) - || (is_storage_row_ortho_v - && is_storage_col_inline_v) - // Mma orthogonality - || (is_storage_col_inline_v - && is_mma_input_v) - || (is_mma_input_v - && is_storage_col_inline_v) - || (is_storage_row_inline_v - && is_mma_input_v) - || (is_mma_input_v - && is_storage_row_inline_v)) - && is_compatible_params_v>> - : public true_type + enable_if_t> + : public integral_constant< + bool, + // Orthogonality in same orientation (e.g., col / row) + ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_inline) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_ortho) + + // Orthogonality in opposite orientation (e.g., col vs row) + || (mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_inline) + + // Check mma input compatibility + || (mat_traits_lhs::is_col_inline && reg_traits_rhs::is_mma_input) + || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_col_inline) + || (mat_traits_lhs::is_row_inline && reg_traits_rhs::is_mma_input) + || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_row_inline)) + && is_compatible_register_params_v> { }; +#undef reg_traits_lhs +#undef reg_traits_rhs +#undef mat_traits_lhs +#undef mat_traits_rhs + // Use generic MatrixLayout orthogonality rules to guide the register layout transpose suggestion template struct orthogonal_layout> diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp new file mode 100644 index 00000000..7f28cc69 --- /dev/null +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -0,0 +1,130 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP +#define ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + namespace RegisterTransform_impl + { + using LayoutTraits_impl::matrix_layout_traits; + using LayoutTraits_impl::register_layout_traits; + +// Keeps things a bit more tidy. Quick access to register layout traits. +#define reg_traits_lhs register_layout_traits +#define reg_traits_rhs register_layout_traits + +// Quick access to matrix layout traits, that are embedded in the register layout traits. +#define mat_traits_lhs matrix_layout_traits +#define mat_traits_rhs matrix_layout_traits + + // Note: If you arrive at an undefined register_transform error, it is likely + // the layout transformation is not currently supported. Need to either implement + // the transform or ensure your layout transform mapping is correct. + template + struct register_layout_transform; + + // No-op transform (same-layout): + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // No-op + return v; + } + }; + + // AOS -> SOA transform (non-interleaved) requirements: + // - Lhs is *Inline + // - layouts are not interleaved + // - layouts are orthogonal + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) + && !mat_traits_lhs::is_interleaved + && is_layout_orthogonal_v>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return Transforma::AosToSoa::exec(forward(v)); + } + }; + + // SOA -> AOS transform (non-interleaved) requirements: + // - Lhs is *Ortho + // - layouts are not interleaved + // - layouts are orthogonal + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) + && !mat_traits_lhs::is_interleaved + && is_layout_orthogonal_v>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return Transforms::SoaToAos::exec(forward(v)); + } + }; + + // Interleaved layout transform: + // - layouts are interleaved + // - layouts are orthogonal + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // TODO: replace with DimPerThread for interleaved. + return interleave(forward(v)); + } + }; + + } // namespace RegisterTransform_impl + +} // namespace rocWMMA + +#endif // ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP diff --git a/library/include/rocwmma/internal/opaque_load.hpp b/library/include/rocwmma/internal/opaque_load.hpp index 3fadf97b..06fdc066 100644 --- a/library/include/rocwmma/internal/opaque_load.hpp +++ b/library/include/rocwmma/internal/opaque_load.hpp @@ -27,7 +27,6 @@ #define ROCWMMA_OPAQUE_LOAD_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "tuple.hpp" #include "types.hpp" #include "vector_iterator.hpp" diff --git a/library/include/rocwmma/internal/opaque_store.hpp b/library/include/rocwmma/internal/opaque_store.hpp index 1f1f9990..7afbd2e7 100644 --- a/library/include/rocwmma/internal/opaque_store.hpp +++ b/library/include/rocwmma/internal/opaque_store.hpp @@ -27,7 +27,6 @@ #define ROCWMMA_OPAQUE_STORE_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "types.hpp" #include "vector_iterator.hpp" @@ -73,10 +72,7 @@ namespace rocwmma using StoreVecTraits = VecTraits; - template + template ROCWMMA_DEVICE static inline auto unroll_right(DataT* dataPtr, Iterator& in, uint32_t ldm, diff --git a/library/include/rocwmma/internal/transforms.hpp b/library/include/rocwmma/internal/transforms.hpp index 1844ca62..82e61dcd 100644 --- a/library/include/rocwmma/internal/transforms.hpp +++ b/library/include/rocwmma/internal/transforms.hpp @@ -63,185 +63,6 @@ namespace rocwmma template using SoaToAos = Driver>; - // Note: If you arrive at an undefined register_transform error, it is likely - // the layout transformation is not currently supported. Need to either implement - // the transform or ensure your layout transform mapping is correct. - template > - struct register_transform; - - // Layouts that are identical do not require register transformations - template - struct register_transform - { - template - ROCWMMA_DEVICE constexpr static inline decltype(auto) - exec(VecT const& v) - { - return v; - } - }; - - /////// To MmaInput /////// - - // ColInlineVW and RowInlineVW layouts are not mma friendly and require Aos->Soa transform. - // Only valid for BlockDims that supported by mma - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::ColInlineVW>, - RegisterLayout::MmaInput, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(RegisterLayout::detail::testSupportedMmaDim(BlockDim), - "Unsupported mma dim"); - - // ColInlineVW -> ColOrthoVW (mma friendly) = AOS -> SOA transform - return AosToSoa::exec(v); - } - }; - - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::RowInlineVW>, - RegisterLayout::MmaInput, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(RegisterLayout::detail::testSupportedMmaDim(BlockDim), - "Unsupported mma dim"); - - // RowInlineVW -> RowOrthoVW (mma friendly) = AOS -> SOA transform - return AosToSoa::exec(v); - } - }; - - /////// To Other Layouts /////// - - // In-register layouts for (RowInlineVW and RowOrthoVW), and (ColInlineVW and ColOrthoVW) are orthgonal - // and need specific transforms to transition between either representation. - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::RowInlineVW>, - RegisterLayout::Storage< - MatrixLayout::RowOrthoVW>, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(RegisterLayout::detail::testSupportedVW( - MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), - "Invalid VW"); - - // RowInlineVW -> RowOrthoVW = AOS -> SOA transform - return AosToSoa::exec(v); - } - }; - - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::RowOrthoVW>, - RegisterLayout::Storage< - MatrixLayout::RowInlineVW>, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(RegisterLayout::detail::testSupportedVW( - MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), - "Invalid VW"); - - // RowOrthoVW -> RowInlineVW = SOA -> AOS transform - return SoaToAos::exec(v); - } - }; - - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::ColInlineVW>, - RegisterLayout::Storage< - MatrixLayout::ColOrthoVW>, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(RegisterLayout::detail::testSupportedVW( - MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), - "Invalid VW"); - - // ColInlineVW -> ColOrthoVW = AOS -> SOA transform - return AosToSoa::exec(v); - } - }; - - template - struct register_transform< - RegisterLayout::Storage< - MatrixLayout::ColOrthoVW>, - RegisterLayout::Storage< - MatrixLayout::ColInlineVW>, - false_type> - { - // TODO: Remove DataT from the transform - template - ROCWMMA_DEVICE constexpr static inline auto exec(VecT const& v) - { - static_assert(0, "Nope"); - static_assert(RegisterLayout::detail::testSupportedVW( - MaxVectorWidth, VectorWidthLhs, VectorWidthRhs), - "Invalid VW"); - - // ColOrthoVW -> ColInlineVW = SOA -> AOS transform - return SoaToAos::exec(v); - } - }; - } // namespace Transforms } // namespace rocwmma diff --git a/library/include/rocwmma/internal/utility/type_traits.hpp b/library/include/rocwmma/internal/utility/type_traits.hpp index ac42d080..595b40eb 100644 --- a/library/include/rocwmma/internal/utility/type_traits.hpp +++ b/library/include/rocwmma/internal/utility/type_traits.hpp @@ -70,7 +70,7 @@ namespace rocwmma // TODO: override namespace not detail using __hip_internal::is_standard_layout; using __hip_internal::is_trivial; - + using detail::is_void; using detail::is_void_v; using detail::remove_const; @@ -85,9 +85,18 @@ namespace rocwmma using detail::remove_volatile_t; using detail::true_type; + // TODO: goes into algorithm using detail::max; using detail::min; + // TODO: goes into functional + using detail::logical_or; + //using detail::logical_or_v; + using detail::logical_and; + //using detail::logical_and_v; + using detail::logical_not; + //using detail::logical_not_v; + } // namespace rocwmma #define ROCWMMA_TYPE_TRAITS_IMPL_NAMESPACE rocwmma::detail @@ -95,6 +104,10 @@ namespace rocwmma #else #include + +// TODO: move to own files +#include +#include namespace rocwmma { // std implementations @@ -147,9 +160,18 @@ namespace rocwmma using std::remove_volatile_t; using std::true_type; + // TODO: goes into algorithm using std::max; using std::min; + // TODO: goes into functional + using std::logical_or; + //using std::logical_or_v; + using std::logical_and; + //using std::logical_and_v; + using std::logical_not; + //using std::logical_not_v; + } // namespace rocwmma #define ROCWMMA_TYPE_TRAITS_IMPL_NAMESPACE std @@ -159,13 +181,13 @@ namespace rocwmma // Define some convenience traits namespace rocwmma { - template + template using enable_if_integral_t = enable_if_t{}>; - - template + + template using enable_if_signed_t = enable_if_t{}>; - template + template using enable_if_arithmetic_t = enable_if_t{}>; } diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index b53b619c..8726f509 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -39,7 +39,7 @@ #include "internal/io_layout.hpp" #include "internal/io_shape.hpp" #include "internal/io_traits.hpp" -#include "internal/layout.hpp" +#include "internal/layout/layout.hpp" #include "internal/mapping_util.hpp" #include "internal/mfma.hpp" #include "internal/opaque_load.hpp" @@ -341,12 +341,12 @@ namespace rocwmma static_assert(IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim, "KDim of input fragments must match"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Input fragment matrix layouts are not orthogonal"); - static_assert(is_same_v, + static_assert(is_layout_same_v, "Input fragment register layouts do not match"); // static_assert(is_same_v, + static_assert(is_layout_orthogonal_v, "Data Layouts are not orthogonal"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Matrix Layouts are not orthogonal"); - static_assert(is_same_v, + static_assert(is_layout_same_v, "Register layouts do not match"); public: @@ -140,16 +140,16 @@ namespace rocwmma static_assert(IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim, "KDim of transposed frag doesn't match"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Data Layouts are not orthogonal"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Matrix Layouts are not orthogonal"); - static_assert(is_same_v, + static_assert(is_layout_same_v, "Register layouts do not match"); public: @@ -230,8 +230,8 @@ namespace rocwmma // Optimal case: input and output register layouts match template - && is_same_v, + enable_if_t + && is_layout_same_v, int> = 0> ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(FragT const& frag) From f27ed38ecbd6e6612093dfa1dcd25db6f80c695d Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 15 Oct 2024 23:45:24 +0000 Subject: [PATCH 07/36] Fixes build after layout folder refactor --- .../include/rocwmma/internal/io_layout.hpp | 3 +- ...traits.hpp => data_layout_traits_impl.hpp} | 10 +-- .../rocwmma/internal/layout/layout_traits.hpp | 59 +++++++++-------- ...aits.hpp => matrix_layout_traits_impl.hpp} | 28 ++++---- ...ts.hpp => register_layout_traits_impl.hpp} | 65 ++++++++++--------- .../layout/register_layout_transforms.hpp | 18 ++++- .../rocwmma/rocwmma_transforms_impl.hpp | 23 +++---- samples/perf_hgemm.cpp | 18 ++--- test/unit/layout_test/device/col_layout.hpp | 30 ++++----- test/unit/layout_test/device/colnt_layout.hpp | 28 ++++---- test/unit/layout_test/device/row_layout.hpp | 30 ++++----- test/unit/layout_test/device/rownt_layout.hpp | 32 +++++---- test/unit/tuple_test/device/tuple.hpp | 3 +- 13 files changed, 178 insertions(+), 169 deletions(-) rename library/include/rocwmma/internal/layout/{data_layout_traits.hpp => data_layout_traits_impl.hpp} (93%) rename library/include/rocwmma/internal/layout/{matrix_layout_traits.hpp => matrix_layout_traits_impl.hpp} (96%) rename library/include/rocwmma/internal/layout/{register_layout_traits.hpp => register_layout_traits_impl.hpp} (89%) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index ebcea724..ca88bcc8 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include "api_fwd.hpp" #include "constants.hpp" +#include "layout/layout.hpp" #include "layout/layout_profile.hpp" #include "types.hpp" diff --git a/library/include/rocwmma/internal/layout/data_layout_traits.hpp b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp similarity index 93% rename from library/include/rocwmma/internal/layout/data_layout_traits.hpp rename to library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp index 6fcb50b0..c86def24 100644 --- a/library/include/rocwmma/internal/layout/data_layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,8 +23,8 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef ROCWMMA_DATA_LAYOUT_TRAITS_HPP -#define ROCWMMA_DATA_LAYOUT_TRAITS_HPP +#ifndef ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP #include "layout.hpp" #include "layout_traits.hpp" @@ -93,7 +93,7 @@ namespace rocwmma template struct orthogonal_layout> { - using Type + using type = DataLayout::template Array1d::type>; }; @@ -101,4 +101,4 @@ namespace rocwmma } // namespace rocwmma -#endif // ROCWMMA_DATA_LAYOUT_TRAITS_HPP +#endif // ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp index 554952ea..82a2411d 100644 --- a/library/include/rocwmma/internal/layout/layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,18 +29,19 @@ // Need strict inclusion order here // clang-format off #include "layout_traits_impl.hpp" -#include "data_layout_traits.hpp" -#include "matrix_layout_traits.hpp" -#include "register_layout_traits.hpp" +#include "data_layout_traits_impl.hpp" +#include "matrix_layout_traits_impl.hpp" +#include "register_layout_traits_impl.hpp" // clang-format on namespace rocwmma { /*! \class is_layout_same - * \brief Compares layout types are the same, or are equivalent. Similar to is_same, - * however layouts can have an equivalency with small variations input parameters such that they - * are still technically the same. This should be used when comparing any layout types: - * DataLayout, MatrixLayout and RegisterLayout + * \brief Compares layout types are the same, or are equivalent. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout. + * DataLayouts are same if they have the same 1D layout in memory. + * MatrixLayouts are the same if they have the same 2D matrix layout in memory. + * RegisterLayouts are the same if they have the same thread mapping in register. * @tparam LhsLayout Comparative left hand side * @tparam RhsLayout Comparative right hand side */ @@ -49,8 +50,21 @@ namespace rocwmma { }; + /*! \class is_layout_same_v + * \brief Evaluates is_layout_same + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + constexpr static inline bool is_layout_same_v = is_layout_same::value; + /*! \class is_layout_orthogonal * \brief Compares layout types if they are orthogonal with each other. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout + * DataLayouts are orthogonal if their 1D layout in memory is opposite (e.g., row major vs col major). + * MatrixLayouts are orthogonal if their 2D matrix layout geometry is transposed. + * RegisterLayouts are orthogonal if they have opposite per-thread mappings: + * Contiguous vector elements in BlockDim (e.g., AOS) vs contiguous vector elements in kDim (e.g., SOA). * @tparam LhsLayout Comparative left hand side * @tparam RhsLayout Comparative right hand side */ @@ -60,23 +74,6 @@ namespace rocwmma { }; - /*! \class orthogonal_layout - * \brief Transforms the layout type into its orthogonal layout. - * @tparam Layout the layout to transpose from - */ - template - struct orthogonal_layout : public LayoutTraits_impl::orthogonal_layout - { - }; - - /*! \class is_layout_same_v - * \brief Evaluates is_layout_same - * @tparam LhsLayout Comparative left hand side - * @tparam RhsLayout Comparative right hand side - */ - template - constexpr static inline bool is_layout_same_v = is_layout_same::value; - /*! \class is_layout_orthogonal * \brief Evaluates is_layout_orthogonal * @tparam LhsLayout Comparative left hand side @@ -86,9 +83,19 @@ namespace rocwmma constexpr static inline bool is_layout_orthogonal_v = is_layout_orthogonal::value; + /*! \class orthogonal_layout + * \brief Provides a guide to an orthogonal layout of the source layout. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout + * @tparam Layout the source layout + */ + template + struct orthogonal_layout : public LayoutTraits_impl::orthogonal_layout + { + }; + /*! \class layout_transpose_t * \brief Transforms the layout type into its orthogonal layout. - * @tparam Layout the layout to transpose from + * @tparam Layout the source layout */ template using orthogonal_layout_t = typename orthogonal_layout::type; diff --git a/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp similarity index 96% rename from library/include/rocwmma/internal/layout/matrix_layout_traits.hpp rename to library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp index ad1e42f1..3ef248fd 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,8 +23,8 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP -#define ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP +#ifndef ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP #include "../config.hpp" #include "layout.hpp" @@ -246,7 +246,9 @@ namespace rocwmma // will be reflective of the input template params of the MatrixLayout instance. template - struct matrix_layout_base_traits; + struct matrix_layout_base_traits + { + }; // Represent non-interleaved MatrixLayout instances template > : public integral_constant< bool, - ((mat_traits_lhs::is_col_ortho_v && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_row_ortho_v && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_col_inline_v && mat_traits_rhs::is_col_inline) - || (mat_traits_lhs::is_row_inline_v && mat_traits_rhs::is_row_inline)) + ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_inline)) && is_compatible_matrix_params_v> { }; @@ -441,10 +443,10 @@ namespace rocwmma enable_if_t> : public integral_constant< bool, - ((mat_traits_lhs::is_col_ortho_v && mat_traits_rhs::is_row_ortho_v) - || (mat_traits_lhs::is_row_ortho_v && mat_traits_rhs::is_col_ortho_v) - || (mat_traits_lhs::is_col_inline_v && mat_traits_rhs::is_row_inline_v) - || (mat_traits_lhs::is_row_inline_v && mat_traits_rhs::is_col_inline_v)) + ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_ortho) + || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_ortho) + || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_inline) + || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_inline)) && is_compatible_matrix_params_v> { }; @@ -539,4 +541,4 @@ namespace rocwmma } // namespace rocwmma -#endif // ROCWMMA_MATRIX_LAYOUT_TRAITS_HPP +#endif // ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_traits.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp similarity index 89% rename from library/include/rocwmma/internal/layout/register_layout_traits.hpp rename to library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 72c76e40..69752ce6 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,8 +23,8 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP -#define ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP +#ifndef ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP #include "../utility/type_traits.hpp" #include "layout.hpp" @@ -41,9 +41,9 @@ namespace rocwmma // Start to build a basic set of meta-data classifiers. // We will be interested in knowing things about our register layouts: // - is_register_layout - // - is_storage_layout - // - is_mma_input_layout - // - is_mma_acc_layout + // - is_storage + // - is_mma_input + // - is_mma_acc template struct is_register_layout : public false_type { @@ -65,32 +65,32 @@ namespace rocwmma }; template - struct is_storage_layout : public false_type + struct is_storage : public false_type { }; template - struct is_storage_layout> : public is_matrix_layout + struct is_storage> : public is_matrix_layout { }; template - struct is_mma_input_layout : public false_type + struct is_mma_input : public false_type { }; template - struct is_mma_input_layout> : public true_type + struct is_mma_input> : public true_type { }; template - struct is_mma_acc_layout : public false_type + struct is_mma_acc : public false_type { }; template - struct is_mma_acc_layout> : public true_type + struct is_mma_acc> : public true_type { }; @@ -100,19 +100,20 @@ namespace rocwmma = is_register_layout::value; template - constexpr inline static bool is_storage_layout_v = is_storage_layout::value; + constexpr inline static bool is_storage_v = is_storage::value; template - constexpr inline static bool is_mma_input_layout_v - = is_mma_input_layout::value; + constexpr inline static bool is_mma_input_v = is_mma_input::value; template - constexpr inline static bool is_mma_acc_layout_v = is_mma_acc_layout::value; + constexpr inline static bool is_mma_acc_v = is_mma_acc::value; // Next we can build a set of base trait accessors for the RegisterLayout. These // will be reflective of the input template params of the RegisterLayout instance. template - struct register_layout_base_traits; + struct register_layout_base_traits + { + }; template struct register_layout_base_traits> @@ -138,10 +139,10 @@ namespace rocwmma template struct register_layout_traits : public register_layout_base_traits { - constexpr static bool is_register_layout = is_register_layout_v; - constexpr static bool is_storage_layout = is_storage_layout_v; - constexpr static bool is_mma_input_layout = is_mma_input_layout_v; - constexpr static bool is_mma_acc_layout = is_mma_acc_layout_v; + constexpr static bool is_register_layout = is_register_layout_v; + constexpr static bool is_storage = is_storage_v; + constexpr static bool is_mma_input = is_mma_input_v; + constexpr static bool is_mma_acc = is_mma_acc_v; }; // NOTE: RegisterLayout assumptions @@ -237,7 +238,7 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public is_compatible_matrix_params { @@ -248,7 +249,7 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant @@ -262,8 +263,8 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t<(reg_traits_lhs::is_storage_layout && !mat_traits_lhs::is_interleaved) - && reg_traits_rhs::is_mma_input_layout>> + enable_if_t<(reg_traits_lhs::is_storage && !mat_traits_lhs::is_interleaved) + && reg_traits_rhs::is_mma_input>> : public integral_constant @@ -274,8 +275,8 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant @@ -289,8 +290,8 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t<(reg_traits_lhs::is_storage_layout && mat_traits_lhs::is_interleaved) - && reg_traits_rhs::is_mma_input_layout>> + enable_if_t<(reg_traits_lhs::is_storage && mat_traits_lhs::is_interleaved) + && reg_traits_rhs::is_mma_input>> : public integral_constant @@ -301,8 +302,8 @@ namespace rocwmma struct is_compatible_register_params< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant @@ -390,4 +391,4 @@ namespace rocwmma } // namespace rocwmma -#endif // ROCWMMA_REGISTER_LAYOUT_TRAITS_HPP +#endif // ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 7f28cc69..d2beada3 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -80,7 +80,7 @@ namespace rocwmma template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) { - return Transforma::AosToSoa::exec(forward(v)); } }; @@ -123,8 +123,22 @@ namespace rocwmma } }; +#undef reg_traits_lhs +#undef reg_traits_rhs +#undef mat_traits_lhs +#undef mat_traits_rhs + } // namespace RegisterTransform_impl + /*! \class register_layout_transform + * \brief Invokes an in-register transform from one register layout to the other + * @tparam RegisterLayoutLhs Source register layout + * @tparam RegisterLayoutRhs Target register layout + */ + template + using register_layout_transform + = RegisterTransform_impl::register_layout_transform; + } // namespace rocWMMA #endif // ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP diff --git a/library/include/rocwmma/rocwmma_transforms_impl.hpp b/library/include/rocwmma/rocwmma_transforms_impl.hpp index 9a24b8b9..cd18da5d 100644 --- a/library/include/rocwmma/rocwmma_transforms_impl.hpp +++ b/library/include/rocwmma/rocwmma_transforms_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,8 @@ #ifndef ROCWMMA_TRANSFORMS_API_IMPL_HPP #define ROCWMMA_TRANSFORMS_API_IMPL_HPP +#include "internal/layout/layout.hpp" +#include "internal/layout/register_layout_transforms.hpp" #include "internal/transforms.hpp" #include "rocwmma_transforms.hpp" @@ -230,7 +232,7 @@ namespace rocwmma // Optimal case: input and output register layouts match template + enable_if_t && is_layout_same_v, int> = 0> @@ -243,7 +245,7 @@ namespace rocwmma template - && !is_same_v, + && !is_layout_same_v, int> = 0> ROCWMMA_DEVICE constexpr static inline auto exec(FragT const& frag) @@ -263,18 +265,9 @@ namespace rocwmma using DstRegLayout = typename GetCoopIOConfig_t::IOLayout::RegisterLayout; - auto result = FragOut{ - Transforms::register_transform::exec(frag.mAccess)}; - //result.mAccess = ; - - // if constexpr(is_same_v) - // { - // result.mAccess = Transforms::AosToSoa::exec(frag.mAccess); - // } - // else if constexpr(is_same_v) - // { - // result.mAccess = Transforms::SoaToAos::exec(frag.mAccess); - // } + auto result = FragOut{}; + result.mAccess + = register_layout_transform::exec(frag.mAccess); return result; } diff --git a/samples/perf_hgemm.cpp b/samples/perf_hgemm.cpp index 2c9f784a..fd469567 100644 --- a/samples/perf_hgemm.cpp +++ b/samples/perf_hgemm.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -375,8 +375,8 @@ ROCWMMA_DEVICE static inline void { constexpr uint32_t VW = 4; - using Profile = rocwmma::LayoutProfile:: - ColInt; + using Profile + = rocwmma::LayoutProfile::ColInt; using DataLayout = typename Profile::DataLayout; using MatrixLayout = typename Profile::MatrixLayout; @@ -404,8 +404,8 @@ ROCWMMA_DEVICE static inline void // How to choose? Comes from the IOConfig? constexpr uint32_t VW = 4; - using Profile = rocwmma::LayoutProfile:: - ColInt; + using Profile + = rocwmma::LayoutProfile::ColInt; using MatrixLayout = typename Profile::MatrixLayout; using DataLayout = typename Profile::DataLayout; @@ -432,8 +432,8 @@ ROCWMMA_DEVICE static inline void globalReadC(GRBuffC& fragsC, OutputT const* gA // How to choose? Comes from the IOConfig? constexpr uint32_t VW = 4; - using Profile = rocwmma::LayoutProfile:: - RowInt; + using Profile + = rocwmma::LayoutProfile::RowInt; using MatrixLayout = typename Profile::MatrixLayout; using DataLayout = typename Profile::DataLayout; @@ -499,8 +499,8 @@ ROCWMMA_DEVICE static inline void globalWriteD(OutputT* gAddrD, GRBuffC const& f // How to choose? Comes from the IOConfig? constexpr uint32_t VW = 4; - using Profile = rocwmma::LayoutProfile:: - RowInt; + using Profile + = rocwmma::LayoutProfile::RowInt; using MatrixLayout = typename Profile::MatrixLayout; using DataLayout = typename Profile::DataLayout; diff --git a/test/unit/layout_test/device/col_layout.hpp b/test/unit/layout_test/device/col_layout.hpp index 278618c7..aee9b1c3 100644 --- a/test/unit/layout_test/device/col_layout.hpp +++ b/test/unit/layout_test/device/col_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,16 +28,13 @@ #define ROCWMMA_DEVICE_COL_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void ColLayout(uint32_t m, uint32_t n, DataT const* in, @@ -46,12 +43,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -61,8 +58,9 @@ namespace rocwmma }; using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile::Col::MatrixLayout; - using Mapping = MappingUtil; + using LayoutT = typename LayoutProfile:: + Col::MatrixLayout; + using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); auto iocount = IOTraits::IOCount; @@ -79,14 +77,14 @@ namespace rocwmma for(int j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + + Mapping::dataOffset(baseOffset, ld) + j; out[index] = in[index]; } baseOffset += LayoutT::incrementalOffset(i); } } } - + } // namespace rocwmma #endif // ROCWMMA_DEVICE_COL_LAYOUT_HPP diff --git a/test/unit/layout_test/device/colnt_layout.hpp b/test/unit/layout_test/device/colnt_layout.hpp index 07e1399a..e19c1165 100644 --- a/test/unit/layout_test/device/colnt_layout.hpp +++ b/test/unit/layout_test/device/colnt_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,16 +28,13 @@ #define ROCWMMA_DEVICE_COLNT_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void ColNTLayout(uint32_t m, uint32_t n, DataT const* in, @@ -46,12 +43,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -61,8 +58,9 @@ namespace rocwmma }; using IOTraits = IOTraits; - using LayoutT - = typename LayoutProfile::ColNT::MatrixLayout; + using LayoutT = typename LayoutProfile:: + ColNT:: + MatrixLayout; using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); @@ -80,7 +78,7 @@ namespace rocwmma for(int j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + + Mapping::dataOffset(baseOffset, ld) + j; out[index] = in[index]; } baseOffset += LayoutT::incrementalOffset(i); diff --git a/test/unit/layout_test/device/row_layout.hpp b/test/unit/layout_test/device/row_layout.hpp index ac12ad23..bd7d2106 100644 --- a/test/unit/layout_test/device/row_layout.hpp +++ b/test/unit/layout_test/device/row_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,17 +28,14 @@ #define ROCWMMA_DEVICE_ROW_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void RowLayout(uint32_t m, uint32_t n, DataT const* in, @@ -47,12 +44,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -68,8 +65,9 @@ namespace rocwmma }; using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile::Row::MatrixLayout; - using Mapping = MappingUtil; + using LayoutT = typename LayoutProfile:: + Row::MatrixLayout; + using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); auto iocount = IOTraits::IOCount; @@ -86,14 +84,14 @@ namespace rocwmma for(uint32_t j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + + Mapping::dataOffset(baseOffset, ld) + j; out[index] = in[index]; } baseOffset += LayoutT::incrementalOffset(i); } } } - + } // namespace rocwmma #endif // ROCWMMA_DEVICE_ROW_LAYOUT_HPP diff --git a/test/unit/layout_test/device/rownt_layout.hpp b/test/unit/layout_test/device/rownt_layout.hpp index beaae4b7..afc2fab2 100644 --- a/test/unit/layout_test/device/rownt_layout.hpp +++ b/test/unit/layout_test/device/rownt_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,17 +28,14 @@ #define ROCWMMA_DEVICE_ROWNT_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void RowNTLayout(uint32_t m, uint32_t n, DataT const* in, @@ -47,12 +44,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -64,14 +61,15 @@ namespace rocwmma MaxVectorWidth = std::is_same_v - ? 1 - : detail::MaxVWSelector::Result, + ? 1 + : detail::MaxVWSelector::Result, VectorWidth = std::is_same_v ? MaxVectorWidth : 1, }; using IOTraits = IOTraits; - using LayoutT - = typename LayoutProfile::RowNT::MatrixLayout; + using LayoutT = typename LayoutProfile:: + RowNT:: + MatrixLayout; using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); @@ -89,7 +87,7 @@ namespace rocwmma for(uint32_t j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + + Mapping::dataOffset(baseOffset, ld) + j; out[index] = in[index]; } baseOffset += LayoutT::incrementalOffset(i); diff --git a/test/unit/tuple_test/device/tuple.hpp b/test/unit/tuple_test/device/tuple.hpp index cce72d5a..8de10318 100644 --- a/test/unit/tuple_test/device/tuple.hpp +++ b/test/unit/tuple_test/device/tuple.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,6 @@ #define ROCWMMA_DEVICE_TUPLE_TEST_HPP #include -#include #include #include #include From f3622e30fbbf71358952d7e9aefc00bb55a35e0f Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 24 Oct 2024 21:37:13 +0000 Subject: [PATCH 08/36] Update interleaving function --- .../rocwmma/internal/vector_util_impl.hpp | 48 ++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/library/include/rocwmma/internal/vector_util_impl.hpp b/library/include/rocwmma/internal/vector_util_impl.hpp index 1096757e..467b858b 100644 --- a/library/include/rocwmma/internal/vector_util_impl.hpp +++ b/library/include/rocwmma/internal/vector_util_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -416,18 +416,44 @@ namespace rocwmma } } - template - ROCWMMA_DEVICE constexpr static inline auto interleave(VecT const& v0) + // A permutation of vector indices, given a gather size and a stride + // Examples: + // row_major col_major + // [0, 1] => interleave<1, 2>([0, 1, 2, 3, 4, 5]) = [0, 2, 4, 1, 3, 5] + // A = [2, 3] col_major row_major + // [4, 5] => interleave<1, 4>([0, 2, 4, 1, 3, 5]) = [0, 1, 2, 3, 4, 5] + // + // [0, 1] + // A = [2, 3] => interleave<2, 4>([0, 1, 2, 3, 4, 5, 6, 7]) = [0, 1, 4, 5, 2, 3, 6, 7] + // [4, 5] + // [6, 7] + // + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) interleave(VecT const& v0) { - // Interleave groups - auto offset = [](auto&& idx, auto&& v0) { - constexpr auto Index = decay_t::value; - constexpr auto Offset0 = Index * GroupSize; - constexpr auto Offset1 = Index / (VecSize / GroupSize); - return get<(Offset0 + Offset1) % VecSize>(v0); - }; + static_assert((GatherSize >= 1u) && (GatherSize <= ElementStride) + && (ElementStride % GatherSize == 0) && (VecSize % GatherSize == 0), + "Invalid GatherSize"); + static_assert(ElementStride >= 1u && ElementStride < VecSize, "Invalid Stride"); + + // No transform is needed (NOP) + if constexpr(GatherSize == ElementStride || ElementStride == VecSize) + { + return v0; + } + else + { + auto offset = [](auto&& idx, auto&& v0) { + constexpr auto Index = decay_t::value; + constexpr auto Offset0 = (Index / GatherSize) * ElementStride % VecSize; + constexpr auto Offset1 = Index % GatherSize; + constexpr auto Offset2 + = (Index * ElementStride) / (VecSize * GatherSize) * GatherSize; + return get(v0); + }; - return vector_generator()(offset, v0); + return vector_generator()(offset, v0); + } } } // namespace rocwmma From 2c550c39b57f474af72942f6f6d7e5b8e1acdb8f Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Sat, 9 Nov 2024 01:24:26 +0000 Subject: [PATCH 09/36] Update is_layout_same and is_layout_orthogonal and matrix layouts logic --- library/include/rocwmma/internal/config.hpp | 15 +- .../layout/data_layout_traits_impl.hpp | 126 +++- .../rocwmma/internal/layout/layout.hpp | 84 ++- .../internal/layout/layout_profile.hpp | 14 +- .../rocwmma/internal/layout/layout_traits.hpp | 15 +- .../internal/layout/layout_traits_impl.hpp | 6 +- .../internal/layout/matrix_layout_impl.hpp | 566 ++++++++++------ .../layout/matrix_layout_traits_impl.hpp | 301 +++++---- .../layout/register_layout_traits_impl.hpp | 617 +++++++++++++----- .../layout/register_layout_transforms.hpp | 92 ++- 10 files changed, 1282 insertions(+), 554 deletions(-) diff --git a/library/include/rocwmma/internal/config.hpp b/library/include/rocwmma/internal/config.hpp index d61dd51e..08b112b5 100644 --- a/library/include/rocwmma/internal/config.hpp +++ b/library/include/rocwmma/internal/config.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -146,6 +146,11 @@ static_assert(0, "Unsupported architecture"); #define ROCWMMA_ARCH_GFX94X 1 #endif +#if ROCWMMA_ARCH_HOST +#define ROCWMMA_BLOCK_DIM_16_SUPPORTED 1 +#define ROCWMMA_BLOCK_DIM_32_SUPPORTED 1 +#endif + #if !defined(ROCWMMA_ARCH_GFX9) #define ROCWMMA_ARCH_GFX9 0 #endif @@ -201,10 +206,10 @@ static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DI #endif #if ROCWMMA_ARCH_GFX12 - static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE), - "rocWMMA supports only wave32 for gfx12 arch"); - static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED), - "rocWMMA supports only block size of 16 for gfx12 arch"); +static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE), + "rocWMMA supports only wave32 for gfx12 arch"); +static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED), + "rocWMMA supports only block size of 16 for gfx12 arch"); #endif /// diff --git a/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp index c86def24..dfdeca09 100644 --- a/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp @@ -33,50 +33,117 @@ namespace rocwmma { namespace LayoutTraits_impl { - // Sameness classifier - template <> - struct is_layout_same : public true_type + // Reference regular layouts + using DataLayout::ColMajor; + using DataLayout::RowMajor; + + // Build a basic set of meta-data classifiers. + // We will be interested in knowing things about our data layouts: + // - is_row_major + // - is_col_major + // - is_data_layout + // + // Note: We will qualify both: + // row_major / col_major (as meta-tags) + // RowMajor and ColMajor (as functional classes) + template + struct is_row_major : public false_type { }; template <> - struct is_layout_same : public true_type + struct is_row_major : public true_type { }; template <> - struct is_layout_same : public true_type + struct is_row_major : public true_type { }; - template <> - struct is_layout_same : public true_type + template + struct is_col_major : public false_type { }; - // Orthogonality classifier template <> - struct is_layout_orthogonal : public true_type + struct is_col_major : public true_type { }; template <> - struct is_layout_orthogonal : public true_type + struct is_col_major : public true_type { }; - template <> - struct is_layout_orthogonal - : public true_type + // Convenience evaluators + template + static constexpr bool is_row_major_v = is_row_major::value; + + template + static constexpr bool is_col_major_v = is_col_major::value; + + template + struct is_data_layout + : public integral_constant || is_col_major_v> { }; - template <> - struct is_layout_orthogonal - : public true_type + // Convenience evaluator + template + static constexpr bool is_data_layout_v = is_data_layout::value; + + // Cumulative traits about our data layouts + template + struct data_layout_traits { + static constexpr bool is_row_major = is_row_major_v; + static constexpr bool is_col_major = is_col_major_v; + static constexpr bool is_data_layout = is_data_layout_v; }; +// Tidy some traits accesses +#define traits_lhs data_layout_traits +#define traits_rhs data_layout_traits + + template + ROCWMMA_HOST_DEVICE constexpr static bool testDataLayoutSame() + { + return (traits_lhs::is_row_major && traits_rhs::is_row_major) + || (traits_lhs::is_col_major && traits_rhs::is_col_major); + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testDataLayoutOrthogonal() + { + return (traits_lhs::is_row_major && traits_rhs::is_col_major) + || (traits_lhs::is_col_major && traits_rhs::is_row_major); + } + + // Implement sameness classifier for data layouts + template + struct is_layout_same> + : public integral_constant()> + { + }; + + // Implement orthogonality classifier for data layouts + template + struct is_layout_orthogonal< + DataLayoutLhs, + DataLayoutRhs, + enable_if_t> + : public integral_constant()> + { + }; + +#undef traits_lhs +#undef traits_rhs + // Orthogonal layout guides template <> struct orthogonal_layout @@ -97,8 +164,35 @@ namespace rocwmma = DataLayout::template Array1d::type>; }; + template + struct layout_traits>> + : public data_layout_traits + { + }; + } // namespace LayoutTraits_impl } // namespace rocwmma +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::data_layout_traits const& traits) + { + using data_traits = decay_t; + stream << "DataLayout Traits: " << DataLayout{} << std::endl; + stream << "is_row_major: " << data_traits::is_row_major << std::endl; + stream << "is_col_major: " << data_traits::is_col_major << std::endl; + stream << "is_data_layout: " << data_traits::is_data_layout << std::endl; + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + #endif // ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index a860b61c..d85b0c77 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -144,20 +144,30 @@ namespace rocwmma // Storage to MmaInput<16> to serve as input to a 16x16xk mma builtin. namespace RegisterLayout { + // Format for data locality + enum struct Format : uint32_t + { + SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] + AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] + None = 2u, + }; + // A mnemonic used to describe the register layout is suitable for input/output - template + template struct Storage { }; // A mnemonic used to describe the register layout is suitable for mma input for A/B - template + template struct MmaInput { }; // A mnemonic used to describe the register layout is suitable for mma input for accumulator input/output - template + template struct MmaAcc { }; @@ -166,6 +176,72 @@ namespace rocwmma } // namespace rocwmma +#if !defined(__HIPCC_RTC__) +namespace std +{ + + inline ostream& operator<<(ostream& stream, rocwmma::row_major const& data_layout) + { + return stream << "row_major"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::col_major const& data_layout) + { + return stream << "col_major"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::DataLayout::RowMajor const& data_layout) + { + return stream << "RowMajor"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::DataLayout::ColMajor const& data_layout) + { + return stream << "ColMajor"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#if !defined(__HIPCC_RTC__) +namespace std +{ + inline ostream& operator<<(ostream& stream, rocwmma::RegisterLayout::Format const& fmt) + { + return stream << (fmt == rocwmma::RegisterLayout::Format::AOS ? "AOS" + : (fmt == rocwmma::RegisterLayout::Format::SOA) ? "SOA" + : "NONE"); + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::Storage const& register_layout) + { + return stream << "Storage<" << MatrixLayout{} << ", " << DataLayout{} << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::MmaInput const& register_layout) + { + return stream << "MmaInput<" << MmaDim << ", " << Interleaved << ", " << Fmt << ">"; + } + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::RegisterLayout::MmaAcc const& register_layout) + { + return stream << "MmaAcc<" << MmaDim << ", " << Interleaved << ", " << Fmt << ">"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + #include "matrix_layout_impl.hpp" #endif // ROCWMMA_LAYOUT_HPP diff --git a/library/include/rocwmma/internal/layout/layout_profile.hpp b/library/include/rocwmma/internal/layout/layout_profile.hpp index cebab878..8ab2ec69 100644 --- a/library/include/rocwmma/internal/layout/layout_profile.hpp +++ b/library/include/rocwmma/internal/layout/layout_profile.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -60,7 +60,7 @@ namespace rocwmma MatrixLayout::ColOrthoVW, MatrixLayout::ColOrthoVW>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -104,7 +104,7 @@ namespace rocwmma MatrixLayout::RowOrthoVW, MatrixLayout::RowOrthoVW>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -149,7 +149,7 @@ namespace rocwmma MatrixLayout::ColInlineVW, MatrixLayout::ColOrthoVW>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -188,7 +188,7 @@ namespace rocwmma MatrixLayout::RowInlineVW, MatrixLayout::RowOrthoVW>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -229,7 +229,7 @@ namespace rocwmma MatrixLayout::ColInlineInt, MatrixLayout::ColOrthoInt>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; @@ -269,7 +269,7 @@ namespace rocwmma MatrixLayout::RowInlineInt, MatrixLayout::RowOrthoInt>; - using RegisterLayout = RegisterLayout::Storage; + using RegisterLayout = RegisterLayout::Storage; // Mapping using MappingUtil = MappingUtil; diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp index 82a2411d..ed775356 100644 --- a/library/include/rocwmma/internal/layout/layout_traits.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -59,12 +59,12 @@ namespace rocwmma constexpr static inline bool is_layout_same_v = is_layout_same::value; /*! \class is_layout_orthogonal - * \brief Compares layout types if they are orthogonal with each other. + * \brief Describes a relationship between two layout endpoints. They are considered orthogonal if they + * are not the same, and there exists a reversible transformation path from one to the other. * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout - * DataLayouts are orthogonal if their 1D layout in memory is opposite (e.g., row major vs col major). - * MatrixLayouts are orthogonal if their 2D matrix layout geometry is transposed. - * RegisterLayouts are orthogonal if they have opposite per-thread mappings: - * Contiguous vector elements in BlockDim (e.g., AOS) vs contiguous vector elements in kDim (e.g., SOA). + * DataLayouts are orthogonal if their 1D layouts in memory are transformable (e.g., row major vs col major). + * MatrixLayouts are orthogonal if their 2D matrix layout geometry is transformable (e.g., layout transpose). + * RegisterLayouts are orthogonal if their in-register layouts are transformable (e.g., AOS vs SOA) * @tparam LhsLayout Comparative left hand side * @tparam RhsLayout Comparative right hand side */ @@ -100,6 +100,11 @@ namespace rocwmma template using orthogonal_layout_t = typename orthogonal_layout::type; + template + struct layout_traits : public LayoutTraits_impl::layout_traits + { + }; + } // namespace rocwmma #endif // ROCWMMA_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp index 5e84fd46..42fbaec4 100644 --- a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,6 +48,10 @@ namespace rocwmma template struct orthogonal_layout; + // Meta traits for layouts + template + struct layout_traits; + } // namespace LayoutTraits_impl } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp index d2d1c84e..a14e2cfe 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -96,85 +96,85 @@ namespace rocwmma uint32_t MaxVectorWidth> struct ColOrthoVW { - using IOTraits = IOTraits; struct Traits { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; + + // Stride between tiles + static constexpr uint32_t BlockDimStride_X = min(BlockDim, WaveSize); + static constexpr uint32_t BlockDimStride_Y = 0u; - // Strides - BlockDimStride_X = min(BlockDim, WaveSize), - BlockDimStride_Y = 0u, + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y + = WaveSize * MaxVectorWidth / BlockDimStride_X; - BlockKStride_X = 0u, - BlockKStride_Y = WaveSize * MaxVectorWidth / BlockDimStride_X, + static constexpr uint32_t VWStride_X = 0u; + static constexpr uint32_t VWStride_Y = VectorWidth; - VWStride_X = 0u, - VWStride_Y = VectorWidth, + // Stride space + static constexpr uint32_t BlockDimSegs = BlockDim / BlockDimStride_X; + static constexpr uint32_t BlockKSegs = BlockK / BlockKStride_Y; + static constexpr uint32_t VWSegs = MaxVectorWidth / VWStride_Y; - // Stride space - BlockDimSegs = BlockDim / BlockDimStride_X, - BlockKSegs = BlockK / BlockKStride_Y, - VWSegs = MaxVectorWidth / VWStride_Y, - }; + // Thread-tile perspective + // TODO: rename to ThreadTile... + static constexpr uint32_t DimPerThread = BlockKSegs; + static constexpr uint32_t KPerThread = MaxVectorWidth; + static constexpr uint32_t ElementsPerThread + = DimPerThread * KPerThread * BlockDimSegs; - static_assert(BlockDim >= (uint32_t)Traits::BlockDimStride_X, + static_assert(BlockDim >= BlockDimStride_X, "BlockDim must be larger than BlockDimStride_X"); - static_assert(BlockDim % (uint32_t)Traits::BlockDimStride_X == 0, + static_assert(BlockDim % BlockDimStride_X == 0, "BlockDim must be a multiple of BlockDimStride_X"); - static_assert(BlockK >= (uint32_t)Traits::BlockKStride_Y, + static_assert(BlockK >= BlockKStride_Y, "BlockK must be larger than BlockKStride_Y"); - static_assert(BlockK % (uint32_t)Traits::BlockKStride_Y == 0, + static_assert(BlockK % BlockKStride_Y == 0, "BlockK must be a multiple of BlockKStride_Y"); - static_assert(MaxVectorWidth >= (uint32_t)Traits::VWStride_Y, + static_assert(MaxVectorWidth >= VWStride_Y, "MaxVectorWidth must larger than VWStride_Y"); - static_assert(MaxVectorWidth % (uint32_t)Traits::VWStride_Y == 0, + static_assert(MaxVectorWidth % VWStride_Y == 0, "MaxVectorWidth must be a multiple of VWStride_Y"); // Orthogonal layout, coordinates are reversed - using OrthoLayout - = RowOrthoVW; + // using OrthoLayout + // = RowOrthoVW; - using MatrixCoordT = Coord2d; + // using MatrixCoordT = Coord2d; }; ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - return make_vector((uint32_t)Traits::BlockDimSegs, // BlockDim Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments + return make_vector(Traits::BlockDimSegs, // BlockDim Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments } ROCWMMA_DEVICE constexpr static inline auto strides() { - return make_vector( - make_coord2d((uint32_t)Traits::BlockDimStride_X, - (uint32_t)Traits::BlockDimStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + return make_vector(make_coord2d(Traits::BlockDimStride_X, Traits::BlockDimStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + ROCWMMA_DEVICE static inline auto baseOffset() { - if constexpr((uint32_t)Traits::BlockDimStride_X >= (uint32_t)Traits::WaveSize) + if constexpr(Traits::BlockDimStride_X >= Traits::WaveSize) { // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, 0u); + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, 0u); } else { // Threads need to spread over the Y direction as well - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, - (threadIdx.x / (uint32_t)Traits::BlockDimStride_X) - * MaxVectorWidth % (uint32_t)Traits::BlockKStride_Y); + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, + (threadIdx.x / Traits::BlockDimStride_X) * MaxVectorWidth + % Traits::BlockKStride_Y); } } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { // Reference: // VWOffsetY = VWStride_Y - ((i+1) % VWSegs ? 0u : VWStride_Y * VWSegs); @@ -212,6 +212,7 @@ namespace rocwmma BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs ? 0 : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::BlockDimSegs > 1) { // "Reset" cycle @@ -242,8 +243,7 @@ namespace rocwmma return make_coord2d(BlockDimOffsetX, VWOffsetY + BlockKOffsetY); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { int32_t cumVWOffsetY = (int32_t)Traits::VWStride_Y * ((int32_t)iteration % (int32_t)Traits::VWSegs); @@ -358,89 +358,89 @@ namespace rocwmma uint32_t MaxVectorWidth> struct ColInlineVW { - using IOTraits = IOTraits; + struct Traits { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; - // Strides - BlockDimStride_X = min(BlockDim, WaveSize), - BlockDimStride_Y = 0u, + // Strides + static constexpr uint32_t BlockDimStride_X = min(BlockDim, WaveSize); + static constexpr uint32_t BlockDimStride_Y = 0u; - BlockKStride_X = 0u, - BlockKStride_Y = WaveSize * MaxVectorWidth / BlockDimStride_X, + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y + = WaveSize * MaxVectorWidth / BlockDimStride_X; - VWStride_X = VectorWidth, - VWStride_Y = 0u, + static constexpr uint32_t VWStride_X = VectorWidth; + static constexpr uint32_t VWStride_Y = 0u; - // Stride Space - BlockDimSegs = BlockDim / BlockDimStride_X, - BlockKSegs = BlockK / BlockKStride_Y, - VWSegs = MaxVectorWidth / VWStride_X, - }; + // Stride Space + static constexpr uint32_t BlockDimSegs = BlockDim / BlockDimStride_X; + static constexpr uint32_t BlockKSegs = BlockK / BlockKStride_Y; + static constexpr uint32_t VWSegs = MaxVectorWidth / VWStride_X; + + // Thread-tile perspective + // TODO: rename to ThreadTile... + static constexpr uint32_t DimPerThread = MaxVectorWidth; + static constexpr uint32_t KPerThread = BlockKSegs; + static constexpr uint32_t ElementsPerThread + = DimPerThread * KPerThread * BlockDimSegs; // Sanity checks for strides sizes - static_assert(BlockDim >= (uint32_t)Traits::BlockDimStride_X, + static_assert(BlockDim >= BlockDimStride_X, "BlockDim must be larger than BlockDimStride_X"); - static_assert(BlockDim % (uint32_t)Traits::BlockDimStride_X == 0, + static_assert(BlockDim % BlockDimStride_X == 0, "BlockDim must be a multiple of BlockDimStride_X"); - static_assert(BlockK >= (uint32_t)Traits::BlockKStride_Y, + static_assert(BlockK >= BlockKStride_Y, "BlockK must be larger than BlockKStride_Y"); - static_assert(BlockK % (uint32_t)Traits::BlockKStride_Y == 0, + static_assert(BlockK % BlockKStride_Y == 0, "BlockK must be a multiple of BlockKStride_Y"); - static_assert(MaxVectorWidth >= (uint32_t)Traits::VWStride_X, + static_assert(MaxVectorWidth >= VWStride_X, "MaxVectorWidth must larger than VWStride_X"); - static_assert(MaxVectorWidth % (uint32_t)Traits::VWStride_X == 0, + static_assert(MaxVectorWidth % VWStride_X == 0, "MaxVectorWidth must be a multiple of VWStride_X"); // Orthogonal layout, coordinates are reversed - using OrthoLayout - = RowInlineVW; + //using OrthoLayout + // = RowInlineVW; - using MatrixCoordT = Coord2d; + //using MatrixCoordT = Coord2d; }; ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - return make_vector((uint32_t)Traits::BlockDimSegs, // BlockDim Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments + return make_vector(Traits::BlockDimSegs, // BlockDim Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments } ROCWMMA_DEVICE constexpr static inline auto strides() { - return make_vector( - make_coord2d((uint32_t)Traits::BlockDimStride_X, - (uint32_t)Traits::BlockDimStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + return make_vector(make_coord2d(Traits::BlockDimStride_X, Traits::BlockDimStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + ROCWMMA_DEVICE static inline auto baseOffset() { - if constexpr(((uint32_t)Traits::BlockDimStride_X >= (uint32_t)Traits::WaveSize) + if constexpr((Traits::BlockDimStride_X >= Traits::WaveSize) && (MaxVectorWidth == 1)) { // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, 0u); + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, 0u); } else { // Threads need to spread over the Y direction as well - return make_coord2d( - threadIdx.x * MaxVectorWidth % (uint32_t)Traits::BlockDimStride_X, - threadIdx.x * MaxVectorWidth / (uint32_t)Traits::BlockDimStride_X - % (uint32_t)Traits::BlockKStride_Y); + return make_coord2d(threadIdx.x * MaxVectorWidth % Traits::BlockDimStride_X, + threadIdx.x * MaxVectorWidth / Traits::BlockDimStride_X + % Traits::BlockKStride_Y); } } // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { // Reference: // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); @@ -509,8 +509,7 @@ namespace rocwmma } // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { int32_t cumVWOffsetX = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); @@ -523,6 +522,7 @@ namespace rocwmma return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); } + ROCWMMA_DEVICE static inline auto debug() {} }; @@ -533,38 +533,34 @@ namespace rocwmma uint32_t SplitK /* = 1*/> // # of splits struct ColInlineInt { - using IOTraits = IOTraits; struct Traits { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = BlockDim / MfmaDim; - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = BlockK * MfmaDim / (WaveSize * SplitK); - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, + // How many elements each thread will gather + static constexpr uint32_t ElementsPerThread = DimPerThread * KPerThread; - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, + // Strides + static constexpr uint32_t SplitKStride_X = 0u; + static constexpr uint32_t SplitKStride_Y = BlockK / SplitK; - BlockKStride_X = 0u, - BlockKStride_Y = 1u, + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y = 1u; - VWStride_X = DimPerThread, - VWStride_Y = 0u, + static constexpr uint32_t VWStride_X = DimPerThread; + static constexpr uint32_t VWStride_Y = 0u; - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = KPerThread / BlockKStride_Y, - VWSegs = DimPerThread / VWStride_X, - }; + // Stride Space + static constexpr uint32_t SplitKSegs = BlockK / SplitKStride_Y; + static constexpr uint32_t BlockKSegs = KPerThread / BlockKStride_Y; + static constexpr uint32_t VWSegs = DimPerThread / VWStride_X; // // Check VectorWidth validity // static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); @@ -572,9 +568,8 @@ namespace rocwmma // "DimPerThread not a multiple of VectorWidth"); // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); + static_assert(BlockK >= KPerThread, "Invalid KPerThread"); + static_assert(BlockK % KPerThread == 0, "BlockK is not a multiple of KPerThread"); // Check SplitK validity static_assert(BlockK >= SplitK, "Invalid SplitK"); @@ -585,39 +580,31 @@ namespace rocwmma static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowInlineInt; + //using OrthoLayout = RowInlineInt; - using MatrixCoordT = Coord2d; + //using MatrixCoordT = Coord2d; }; ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - - return make_vector((uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); + return make_vector(Traits::SplitKSegs, Traits::BlockKSegs, Traits::VWSegs); } ROCWMMA_DEVICE constexpr static inline auto strides() { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + return make_vector(make_coord2d(Traits::SplitKStride_X, Traits::SplitKStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + ROCWMMA_DEVICE static inline auto baseOffset() { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); + return make_coord2d((threadIdx.x * Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * Traits::KPerThread) % BlockK); } // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { // Reference: // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); @@ -686,8 +673,7 @@ namespace rocwmma } // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { int32_t cumVWOffsetX = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); @@ -702,29 +688,29 @@ namespace rocwmma } ROCWMMA_DEVICE static inline auto debug() { - if(threadIdx.x == 0 && threadIdx.y == 0) - { - printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", - (uint32_t)Traits::SplitKSegs, - (uint32_t)Traits::BlockKSegs, - (uint32_t)Traits::VWSegs); - - printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " - "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", - (uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y, - (uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y, - (uint32_t)Traits::VWStride_X, - (uint32_t)Traits::VWStride_Y); - } - if(threadIdx.x <= 63 && threadIdx.y == 0) - { - printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", - threadIdx.x, - get<0>(baseOffset()), - get<1>(baseOffset())); - } + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " + // "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", + // threadIdx.x, + // get<0>(baseOffset()), + // get<1>(baseOffset())); + // } } }; @@ -735,43 +721,38 @@ namespace rocwmma uint32_t SplitK /*= 1*/> // # of splits struct ColOrthoInt { - using IOTraits = IOTraits; struct Traits { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; - // Number of elements each thread will fetch in BlockDim direction - DimPerThread = BlockDim / MfmaDim, + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = BlockDim / MfmaDim; - // Number of elements each thread will fetch in BlockK direction - KPerThread = BlockK * MfmaDim / (WaveSize * SplitK), + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = BlockK * MfmaDim / (WaveSize * SplitK); - // Number of elements that each thread is responsible for - ElementsPerThread = DimPerThread * KPerThread, + // Number of elements that each thread is responsible for + static constexpr uint32_t ElementsPerThread = DimPerThread * KPerThread; - // Strides - SplitKStride_X = 0u, - SplitKStride_Y = BlockK / SplitK, + // Strides + static constexpr uint32_t SplitKStride_X = 0u; + static constexpr uint32_t SplitKStride_Y = BlockK / SplitK; - BlockKStride_X = 1u, - BlockKStride_Y = 0u, + static constexpr uint32_t BlockKStride_X = 1u; + static constexpr uint32_t BlockKStride_Y = 0u; - VWStride_X = 0u, - VWStride_Y = DimPerThread, + static constexpr uint32_t VWStride_X = 0u; + static constexpr uint32_t VWStride_Y = KPerThread; - // Stride Space - SplitKSegs = BlockK / SplitKStride_Y, - BlockKSegs = DimPerThread / BlockKStride_X, - VWSegs = KPerThread / VWStride_Y, - }; + // Stride Space + static constexpr uint32_t SplitKSegs = BlockK / SplitKStride_Y; + static constexpr uint32_t BlockKSegs = DimPerThread / BlockKStride_X; + static constexpr uint32_t VWSegs = KPerThread / VWStride_Y; // Check KPerThread validity - static_assert(BlockK >= (uint32_t)Traits::KPerThread, "Invalid KPerThread"); - static_assert(BlockK % (uint32_t)Traits::KPerThread == 0, - "BlockK is not a multiple of KPerThread"); + static_assert(BlockK >= KPerThread, "Invalid KPerThread"); + static_assert(BlockK % KPerThread == 0, "BlockK is not a multiple of KPerThread"); // // Check VectorWidth validity // static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); @@ -787,38 +768,33 @@ namespace rocwmma static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); // Orthogonal layout, coordinates are reversed - using OrthoLayout = RowOrthoInt; + //using OrthoLayout = RowOrthoInt; - using MatrixCoordT = Coord2d; + //using MatrixCoordT = Coord2d; }; ROCWMMA_DEVICE constexpr static inline auto strideCounts() { - return make_vector((uint32_t)Traits::SplitKSegs, // WaveKSegs Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments + return make_vector(Traits::SplitKSegs, // WaveKSegs Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments } ROCWMMA_DEVICE constexpr static inline auto strides() { - return make_vector( - make_coord2d((uint32_t)Traits::SplitKStride_X, - (uint32_t)Traits::SplitKStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); + return make_vector(make_coord2d(Traits::SplitKStride_X, Traits::SplitKStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() + ROCWMMA_DEVICE static inline auto baseOffset() { - return make_coord2d((threadIdx.x * (uint32_t)Traits::DimPerThread) % BlockDim, - (threadIdx.x / MfmaDim * (uint32_t)Traits::KPerThread) - % BlockK); + return make_coord2d((threadIdx.x * Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * Traits::KPerThread) % BlockK); } // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { // Reference: // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); @@ -887,8 +863,7 @@ namespace rocwmma } // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) { int32_t cumVWOffsetX = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); @@ -927,14 +902,68 @@ namespace rocwmma } }; + template + struct OrthoTraits; + + template + struct OrthoTraits> + { + // Number of threads per wave + static constexpr uint32_t WaveSize = MatrixLayout::Traits::WaveSize; + + // Strides (swapped) + static constexpr uint32_t BlockDimStride_X = MatrixLayout::Traits::BlockDimStride_Y; + static constexpr uint32_t BlockDimStride_Y = MatrixLayout::Traits::BlockDimStride_X; + + static constexpr uint32_t BlockKStride_X = MatrixLayout::Traits::BlockKStride_Y; + static constexpr uint32_t BlockKStride_Y = MatrixLayout::Traits::BlockKStride_X; + + static constexpr uint32_t VWStride_X = MatrixLayout::Traits::VWStride_Y; + static constexpr uint32_t VWStride_Y = MatrixLayout::Traits::VWStride_X; + + // Stride space (same) + static constexpr uint32_t BlockDimSegs = MatrixLayout::Traits::BlockDimSegs; + static constexpr uint32_t BlockKSegs = MatrixLayout::Traits::BlockKSegs; + static constexpr uint32_t VWSegs = MatrixLayout::Traits::VWSegs; + }; + + template + struct OrthoTraits> + { + // Number of threads per wave + static constexpr uint32_t WaveSize = MatrixLayout::Traits::WaveSize; + + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = MatrixLayout::Traits::DimPerThread; + + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = MatrixLayout::Traits::KPerThread; + + // Number of elements that each thread is responsible for + static constexpr uint32_t ElementsPerThread = MatrixLayout::Traits::ElementsPerThread; + + // Swapped strides + static constexpr uint32_t SplitKStride_X = MatrixLayout::Traits::SplitKStride_Y; + static constexpr uint32_t SplitKStride_Y = MatrixLayout::Traits::SplitKStride_X; + + static constexpr uint32_t BlockKStride_X = MatrixLayout::Traits::BlockKStride_Y; + static constexpr uint32_t BlockKStride_Y = MatrixLayout::Traits::BlockKStride_X; + + static constexpr uint32_t VWStride_X = MatrixLayout::Traits::VWStride_Y; + static constexpr uint32_t VWStride_Y = MatrixLayout::Traits::VWStride_X; + + // Stride Space + static constexpr uint32_t SplitKSegs = MatrixLayout::Traits::SplitKSegs; + static constexpr uint32_t BlockKSegs = MatrixLayout::Traits::BlockKSegs; + static constexpr uint32_t VWSegs = MatrixLayout::Traits::VWSegs; + }; + template struct OrthoImpl { - // Matrix coord offsets - ROCWMMA_DEVICE static inline auto baseOffset() + struct Traits : public OrthoTraits { - return swap(MatrixLayout::baseOffset()); - } + }; ROCWMMA_DEVICE constexpr static inline auto strideCounts() { @@ -949,6 +978,11 @@ namespace rocwmma return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); } + ROCWMMA_DEVICE static inline auto baseOffset() + { + return swap(MatrixLayout::baseOffset()); + } + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) { return swap(MatrixLayout::incrementalOffset(iteration)); @@ -1005,4 +1039,116 @@ namespace rocwmma } // namespace rocwmma +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + ColOrthoVW const& matrix_layout) + { + return stream << "ColOrthoVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + ColInlineVW const& matrix_layout) + { + return stream << "ColInlineVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + RowOrthoVW const& matrix_layout) + { + return stream << "ColOrthoVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + RowInlineVW const& matrix_layout) + { + return stream << "ColInlineVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::ColOrthoInt const& + matrix_layout) + { + return stream << "ColOrthoInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::ColInlineInt const& + matrix_layout) + { + return stream << "ColInlineInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::RowOrthoInt const& + matrix_layout) + { + return stream << "ColOrthoInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::RowInlineInt const& + matrix_layout) + { + return stream << "ColInlineInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + #endif // ROCWMMA_MATRIX_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp index 3ef248fd..f79188f6 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp @@ -35,20 +35,6 @@ namespace rocwmma // Common helpers for supported traits namespace LayoutTraits_impl { - // Based on the current config, determine the compatibility of the mma dimension - constexpr static inline bool testSupportedMmaDim(uint32_t MmaDim) - { - return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && MmaDim == 16u) - || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && (MmaDim == 16u || MmaDim == 32u)); - } - - // VW can be changed from vw0 to vw1 as long as they have the same maxVW, and that maxVW - // is a multiple of both vw values - constexpr static inline bool testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) - { - return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); - } - // Reference regular layouts using MatrixLayout::ColInlineVW; using MatrixLayout::ColOrthoVW; @@ -61,7 +47,7 @@ namespace rocwmma using MatrixLayout::RowInlineInt; using MatrixLayout::RowOrthoInt; - // Start to build a basic set of meta-data classifiers. + // Build a basic set of meta-data classifiers. // We will be interested in knowing things about our matrix layouts: // - is_col_ortho // - is_row_ortho @@ -242,15 +228,35 @@ namespace rocwmma template constexpr static bool is_matrix_layout_v = is_matrix_layout::value; - // Next we can build a set of base trait accessors for the MatrixLayout. These - // will be reflective of the input template params of the MatrixLayout instance. + template + struct matrix_layout_classifier_traits + { + // Add associative traits + constexpr static bool is_col_ortho = is_col_ortho_v; + constexpr static bool is_col_inline = is_col_inline_v; + constexpr static bool is_row_ortho = is_row_ortho_v; + constexpr static bool is_row_inline = is_row_inline_v; + constexpr static bool is_interleaved = is_interleaved_v; + constexpr static bool is_matrix_layout = is_matrix_layout_v; + }; template - struct matrix_layout_base_traits + struct matrix_layout_derived_traits { + // Interface for params we want to derive from matrix layouts + constexpr static uint32_t BlockDim = 0u; + constexpr static uint32_t KDim = 0u; + using DataT = void; + constexpr static uint32_t VectorWidth = 0u; + constexpr static uint32_t MaxVectorWidth = 0u; + constexpr static uint32_t MmaDim = 0u; + constexpr static uint32_t SplitK = 0u; }; - // Represent non-interleaved MatrixLayout instances +#define matrix_layout \ + MatrixLayout + + // Combine internal layout traits with template params template class MatrixLayout> - struct matrix_layout_base_traits< - MatrixLayout, - enable_if_t> - && !is_interleaved_v>>> + struct matrix_layout_derived_traits< + matrix_layout, + enable_if_t && !is_interleaved_v>> + : public matrix_layout::Traits // Base traits { + // Common params derived from template params constexpr static uint32_t BlockDim = LayoutBlockDim; constexpr static uint32_t KDim = LayoutBlockK; using DataT = LayoutDataT; constexpr static uint32_t VectorWidth = LayoutVectorWidth; constexpr static uint32_t MaxVectorWidth = LayoutMaxVectorWidth; + constexpr static uint32_t MmaDim = LayoutBlockDim; // Effective MmaDim + constexpr static uint32_t SplitK = 0; // Unused }; +#undef matrix_layout + +#define matrix_layout \ + MatrixLayout + // Represent interleaved MatrixLayout instances template class MatrixLayout> - struct matrix_layout_base_traits< - MatrixLayout, - enable_if_t> - && is_interleaved_v>>> - { - constexpr static uint32_t BlockDim = LayoutBlockDim; - constexpr static uint32_t KDim = LayoutBlockK; - using DataT = LayoutDataT; - constexpr static uint32_t MmaDim = LayoutMmaDim; - constexpr static uint32_t SplitK = LayoutSplitK; + struct matrix_layout_derived_traits< + matrix_layout, + enable_if_t && is_interleaved_v>> + : public matrix_layout::Traits // Base traits + { + private: + // Wrapper to get fixed MaxVectorWidth / VectorWidth from layout + constexpr static inline uint32_t calcMaxVw() + { + if constexpr(is_col_inline_v || is_row_inline_v) + { + return matrix_layout::Traits::DimPerThread; + } + else if constexpr(is_col_ortho_v || is_row_ortho_v) + { + return matrix_layout::Traits::KPerThread; + } + else + { + return 0; + } + } + + public: + // Common params derived from template params + constexpr static uint32_t BlockDim = LayoutBlockDim; + constexpr static uint32_t KDim = LayoutBlockK; + using DataT = LayoutDataT; + constexpr static uint32_t VectorWidth = calcMaxVw(); + constexpr static uint32_t MaxVectorWidth = calcMaxVw(); + constexpr static uint32_t MmaDim = LayoutMmaDim; + constexpr static uint32_t SplitK = LayoutSplitK; }; +#undef matrix_layout + // Combine base instance traits with specific layout classifiers template - struct matrix_layout_traits : public matrix_layout_base_traits + struct matrix_layout_traits : public matrix_layout_derived_traits, + public matrix_layout_classifier_traits { - constexpr static bool is_col_ortho = is_col_ortho_v; - constexpr static bool is_col_inline = is_col_inline_v; - constexpr static bool is_row_ortho = is_row_ortho_v; - constexpr static bool is_row_inline = is_row_inline_v; - constexpr static bool is_interleaved = is_interleaved_v; - constexpr static bool is_matrix_layout = is_matrix_layout_v; }; // NOTE: MatrixLayout assumptions @@ -365,94 +378,106 @@ namespace rocwmma // Following the above traits, we can build more complicated traits such as // is_same, is_orthogonal and orthogonal_layout. - // When comparing one MatrixLayout to another, we need a way to check parameter compatibility. - template - struct is_compatible_matrix_params : public false_type - { - }; +// Tidy access to matrix layout traits. +#define traits_lhs matrix_layout_traits +#define traits_rhs matrix_layout_traits -// Keeps things a bit more tidy. Quick access to matrix layout traits. -#define mat_traits_lhs matrix_layout_traits -#define mat_traits_rhs matrix_layout_traits + // For a fixed maxVW, we can change VW of a matrix layout to any common divisor + ROCWMMA_HOST_DEVICE constexpr static inline bool + testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) + { + return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); + } - // Non-interleaved matrix layout compatibility requires: - // 1. Must have fixed: BlockDim, KDim, MaxVectorWidth - // 2. VectorWidths must satisfy criterion in testSupportedVW(). + // As a predicate to is_layout_same or is_layout_orthogonal, their matrix parameters must + // be compatible (see above table). template - struct is_compatible_matrix_params< - MatrixLayoutLhs, - MatrixLayoutRhs, - enable_if_t<(!mat_traits_lhs::is_interleaved && !mat_traits_rhs::is_interleaved)>> - : public integral_constant - { - }; + ROCWMMA_HOST_DEVICE constexpr static bool testCompatibleMatrixParams() + { + if constexpr(!traits_lhs::is_matrix_layout && !traits_rhs::is_matrix_layout) + { + return false; + } + else if constexpr(!traits_lhs::is_interleaved && !traits_rhs::is_interleaved) + { + // Non-interleaved matrix layout compatibility requires: + // 1. Fixed: BlockDim, KDim, MaxVectorWidth + // 2. VectorWidths must satisfy criterion in testSupportedVW(). + return (traits_lhs::BlockDim == traits_rhs::BlockDim) + && (traits_lhs::KDim == traits_rhs::KDim) + && (traits_lhs::MaxVectorWidth == traits_rhs::MaxVectorWidth) + && (testSupportedVW(traits_lhs::MaxVectorWidth, + traits_lhs::VectorWidth, + traits_rhs::VectorWidth)); + } + else if constexpr(traits_lhs::is_interleaved && traits_rhs::is_interleaved) + { + // Interleaved matrix layout compatibility requires: + // 1. Must have fixed BlockDim, BlockK, MmaDim, SplitK + // 2. MmaDim values must satisfy criterion in testSupportedMmaDim(). + return (traits_lhs::BlockDim == traits_rhs::BlockDim) + && (traits_lhs::KDim == traits_rhs::KDim) + && (traits_lhs::MmaDim == traits_rhs::MmaDim) + && (traits_lhs::SplitK == traits_rhs::SplitK) + && (traits_lhs::DimPerThread == traits_rhs::DimPerThread) + && (traits_lhs::KPerThread == traits_rhs::KPerThread); + } + else + { + return false; + } + } - // Interleaved matrix layout compatibility requires: - // 1. Must have fixed BlockDim, BlockK, MmaDim, SplitK - // 2. MmaDim values must satisfy criterion in testSupportedMmaDim(). + // Test for same layout template - struct is_compatible_matrix_params< - MatrixLayoutLhs, - MatrixLayoutRhs, - enable_if_t<(mat_traits_lhs::is_interleaved && mat_traits_rhs::is_interleaved)>> - : public integral_constant + ROCWMMA_HOST_DEVICE constexpr static bool testMatrixLayoutSame() { - }; + return ((traits_lhs::is_col_ortho && traits_rhs::is_col_ortho) + || (traits_lhs::is_row_ortho && traits_rhs::is_row_ortho) + || (traits_lhs::is_col_inline && traits_rhs::is_col_inline) + || (traits_lhs::is_row_inline && traits_rhs::is_row_inline)) + && testCompatibleMatrixParams(); + } - // Convenience evaluator + // Test for orthogonal layout template - constexpr static bool is_compatible_matrix_params_v - = is_compatible_matrix_params::value; + ROCWMMA_HOST_DEVICE constexpr static bool testMatrixLayoutOrthogonal() + { + return ((traits_lhs::is_col_ortho && traits_rhs::is_row_ortho) + || (traits_lhs::is_row_ortho && traits_rhs::is_col_ortho) + || (traits_lhs::is_col_inline && traits_rhs::is_row_inline) + || (traits_lhs::is_row_inline && traits_rhs::is_col_inline)) + && testCompatibleMatrixParams(); + } // Now to implement the interfaces for is_layout_same and is_layout_orthogonal, // with MatrixLayout types. - // Classifier to test same-ness, implements criterion #1 from above: + // Implement sameness classifier for matrix layouts template struct is_layout_same< MatrixLayoutLhs, MatrixLayoutRhs, - enable_if_t> - : public integral_constant< - bool, - ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_inline)) - && is_compatible_matrix_params_v> + enable_if_t> + : public integral_constant()> { }; - // Classifier to test orthogonality, implements criterion #2 from above: + // Implement orthogonality classifier for matrix layouts template struct is_layout_orthogonal< MatrixLayoutLhs, MatrixLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant< bool, - ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_inline)) - && is_compatible_matrix_params_v> + testMatrixLayoutOrthogonal()> { }; -#undef mat_traits_lhs -#undef mat_traits_rhs +#undef traits_lhs +#undef traits_rhs // Matrix space transpose guide: Swap rows / cols // VW stays consistent. @@ -536,9 +561,45 @@ namespace rocwmma { using type = ColInlineInt; }; + template + struct layout_traits>> + : public matrix_layout_traits + { + }; } // namespace LayoutTraits_impl } // namespace rocwmma +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::matrix_layout_traits const& traits) + { + using matrix_traits = decay_t; + + stream << "MatrixLayout Traits: " << MatrixLayout{} << std::endl; + stream << "is_col_ortho: " << matrix_traits::is_col_ortho << std::endl; + stream << "is_row_ortho: " << matrix_traits::is_row_ortho << std::endl; + stream << "is_col_inline: " << matrix_traits::is_col_inline << std::endl; + stream << "is_row_inline: " << matrix_traits::is_row_inline << std::endl; + stream << "is_interleaved: " << matrix_traits::is_interleaved << std::endl; + stream << "is_matrix_layout: " << matrix_traits::is_matrix_layout << std::endl; + stream << "BlockDim: " << matrix_traits::BlockDim << std::endl; + stream << "KDim: " << matrix_traits::KDim << std::endl; + stream << "MmaDim: " << matrix_traits::MmaDim << std::endl; + stream << "SplitK: " << matrix_traits::SplitK << std::endl; + stream << "VectorWidth: " << matrix_traits::VectorWidth << std::endl; + stream << "MaxVectorWidth: " << matrix_traits::MaxVectorWidth << std::endl; + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + #endif // ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 69752ce6..7411f6ce 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -49,18 +49,19 @@ namespace rocwmma { }; - template - struct is_register_layout> : public is_matrix_layout + template + struct is_register_layout> + : public is_matrix_layout { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; @@ -69,8 +70,8 @@ namespace rocwmma { }; - template - struct is_storage> : public is_matrix_layout + template + struct is_storage> : public is_matrix_layout { }; @@ -79,8 +80,8 @@ namespace rocwmma { }; - template - struct is_mma_input> : public true_type + template + struct is_mma_input> : public true_type { }; @@ -89,8 +90,8 @@ namespace rocwmma { }; - template - struct is_mma_acc> : public true_type + template + struct is_mma_acc> : public true_type { }; @@ -108,41 +109,148 @@ namespace rocwmma template constexpr inline static bool is_mma_acc_v = is_mma_acc::value; - // Next we can build a set of base trait accessors for the RegisterLayout. These - // will be reflective of the input template params of the RegisterLayout instance. template - struct register_layout_base_traits + struct register_layout_classifier_traits + { + constexpr static bool is_register_layout = is_register_layout_v; + constexpr static bool is_storage = is_storage_v; + constexpr static bool is_mma_input = is_mma_input_v; + constexpr static bool is_mma_acc = is_mma_acc_v; + }; + + template + struct register_layout_traits; + + // Test the consistency of matrix layouts under different data layouts. + // RegisterLayouts are consistent for both data layouts if we restrict + // VectorWidth to 1 in the opposite data layout grain. + // This applies to all matrix layouts. + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutIdentity() + { + using traits = register_layout_traits; + + if constexpr(traits::is_col_inline) + { + return (traits::is_col_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_row_inline) + { + return (traits::is_row_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_col_ortho) + { + return (traits::is_row_major || traits::VectorWidth == 1u); + } + else if constexpr(traits::is_row_ortho) + { + return (traits::is_col_major || traits::VectorWidth == 1u); + } + + return false; + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutAos() + { + using traits = register_layout_traits; + + // AOS is a strict register layout where contiguous elements + // capture contiguous BlockDim elements and must be consistent. + return (traits::is_col_inline || traits::is_row_inline) + && testStorageLayoutIdentity(); + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutSoa() + { + using traits = register_layout_traits; + + // SOA is a strict register layout where contiguous elements + // capture contiguous BlockK elements and must be consistent. + return (traits::is_col_ortho || traits::is_row_ortho) + && testStorageLayoutIdentity(); + } + + // Based on the current config, mma dimensions supported + template + ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaDim() + { + using traits = register_layout_traits; + return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && traits::MmaDim == 16u) + || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && (traits::MmaDim == 16u || traits::MmaDim == 32u)); + } + + template + struct register_layout_derived_traits { }; - template - struct register_layout_base_traits> + template + struct register_layout_derived_traits> + : public matrix_layout_traits, + public data_layout_traits { using MatrixLayout = MatrixLayoutInternal; + using DataLayout = DataLayoutInternal; + + constexpr static bool is_aos_format + = testStorageLayoutAos>(); + constexpr static bool is_soa_format + = testStorageLayoutSoa>(); + constexpr static bool is_valid + = testStorageLayoutIdentity>(); + + constexpr static RegisterLayout::Format Format + = is_aos_format ? RegisterLayout::Format::AOS + : (is_soa_format ? RegisterLayout::Format::SOA + : RegisterLayout::Format::None); }; - template - struct register_layout_base_traits> + template + struct register_layout_derived_traits> + : public matrix_layout_traits, public data_layout_traits { - constexpr static uint32_t MmaDim = LayoutMmaDim; - using MatrixLayout = void; + using MatrixLayout = void; + using DataLayout = void; + + // Overrides + constexpr static bool is_interleaved = LayoutIsInterleaved; + constexpr static uint32_t MmaDim = LayoutMmaDim; + + constexpr static bool is_aos_format = (Fmt == RegisterLayout::Format::AOS); + constexpr static bool is_soa_format = (Fmt == RegisterLayout::Format::SOA); + constexpr static bool is_valid + = testSupportedMmaDim>(); + + constexpr static RegisterLayout::Format Format = Fmt; }; - template - struct register_layout_base_traits> + template + struct register_layout_derived_traits> + : public matrix_layout_traits, public data_layout_traits { - constexpr static uint32_t MmaDim = LayoutMmaDim; - using MatrixLayout = void; + using MatrixLayout = void; + using DataLayout = void; + + // Overrides + constexpr static bool is_interleaved = LayoutIsInterleaved; + constexpr static uint32_t MmaDim = LayoutMmaDim; + + constexpr static bool is_aos_format = (Fmt == RegisterLayout::Format::AOS); + constexpr static bool is_soa_format = (Fmt == RegisterLayout::Format::SOA); + constexpr static bool is_valid + = testSupportedMmaDim>(); + + constexpr static RegisterLayout::Format Format = Fmt; }; // Combine base instance traits with specific layout classifiers template - struct register_layout_traits : public register_layout_base_traits + struct register_layout_traits : public register_layout_derived_traits, + public register_layout_classifier_traits { - constexpr static bool is_register_layout = is_register_layout_v; - constexpr static bool is_storage = is_storage_v; - constexpr static bool is_mma_input = is_mma_input_v; - constexpr static bool is_mma_acc = is_mma_acc_v; }; // NOTE: RegisterLayout assumptions @@ -171,6 +279,10 @@ namespace rocwmma // | MmaInput | Storage | BlockDim == MmaDim | // | Storage | MmaInput | BlockDim == MmaDim | // | MmaInput | Storage | BlockDim == MmaDim | + // | Storage | MmaAcc | BlockDim == MmaDim, MaxVW = 4* | + // | MmaAcc | Storage | BlockDim == MmaDim, MaxVW = 4* | + // | Storage | MmaAcc | BlockDim == MmaDim, MaxVW = 4* | + // | MmaAcc | Storage | BlockDim == MmaDim, MaxVW = 4* | * = arch dependent // | ------------------------------------------------------------------------------- | // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | @@ -193,7 +305,7 @@ namespace rocwmma // MmaDim, MmaInput is also orthogonal to Storage layouts. // _______________________________________________________________________________ // | MatrixLayoutLhs | MatrixLayoutRhs | Required Fixed Params | - // | | (Transposed) | | + // | | (Orthogonal) | | // | ----------------------------------------------------------------------------- | // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | @@ -207,6 +319,10 @@ namespace rocwmma // | MmaInput | Storage | BlockDim == MmaDim | // | Storage | MmaInput | BlockDim == MmaDim | // | MmaInput | Storage | BlockDim == MmaDim | + // | Storage | MmaAcc | BlockDim == MmaDim | + // | MmaAcc | Storage | BlockDim == MmaDim | + // | Storage | MmaInput | BlockDim == MmaDim | + // | MmaInput | Storage | BlockDim == MmaDim | // | ----------------------------------------------------------------------------- | // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | @@ -220,129 +336,304 @@ namespace rocwmma // | MmaInput | Storage| MmaDim | // | Storage | MmaInput | MmaDim | // | MmaInput | Storage| MmaDim | + // | Storage | MmaInput | MmaDim | + // | MmaInput | Storage| MmaDim | // | ----------------------------------------------------------------------------- | // Keeps things a bit more tidy. Quick access to register layout traits. -#define reg_traits_lhs register_layout_traits -#define reg_traits_rhs register_layout_traits - -// Quick access to matrix layout traits, that are embedded in the register layout traits. -#define mat_traits_lhs matrix_layout_traits -#define mat_traits_rhs matrix_layout_traits - - template - struct is_compatible_register_params; +#define traits_lhs register_layout_traits +#define traits_rhs register_layout_traits +#define traits register_layout_traits - // Compatibility for Storage is a passthrough to MatrixLayout compatibility. - template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t> - : public is_compatible_matrix_params + template + ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaAccMaxVW() { - }; + // Test the MaxVectorWidth of storage layouts for MMA requirements. + if constexpr(traits::is_storage) + { + // Interleaved storage layouts not compatible with MmaAcc + if constexpr(traits::is_interleaved) + { + return false; + } + else if constexpr((bool)ROCWMMA_ARCH_GFX12) + { + return traits::MaxVectorWidth == 8u; + } + else if constexpr((bool)ROCWMMA_ARCH_GFX11 + || is_same::value) + { + return traits::MaxVectorWidth == 1u; + } + else // General case + { + return traits::MaxVectorWidth == 4u; + } + } + + // Mma input not compatible with acc + return traits::is_mma_acc; + } + + // Test the consistency of matrix layouts under different data layouts. + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutIdentity() + { + if constexpr(traits::is_storage) + { + // RegisterLayouts are consistent for both data layouts if we restrict + // VectorWidth to 1 in the opposite data layout grain. + if constexpr(traits::is_col_inline) + { + return (traits::is_col_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_row_inline) + { + return (traits::is_row_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_col_ortho) + { + return (traits::is_row_major || traits::VectorWidth == 1u); + } + else if constexpr(traits::is_row_ortho) + { + return (traits::is_col_major || traits::VectorWidth == 1u); + } + } + + // Mma input and acc are symbolic register layouts. + // Both are consistent in either row/col major data layouts. + return traits::is_mma_input || traits::is_mma_acc; + } - // Compatibility for MmaInputs - template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t> - : public integral_constant + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutAos() { - }; + // AOS is a strict register layout where contiguous elements + // capture contiguous BlockDim elements and must be consistent. + if constexpr(traits::is_storage) + { + return (traits::is_col_inline || traits::is_row_inline) + && testRegisterLayoutIdentity(); + } + else + { + // None of the MMA inputs are AOS + return !traits::is_mma_input && !traits::is_mma_acc; + } + } - // Non-interleaved register layout compatibility with MmaInput requires: - // 1. Inner matrix layout and mma input layout must have same: BlockDim / MmaDim - // 2. MmaDim must satisfy criterion in testSupportedMmaDim(). - template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t<(reg_traits_lhs::is_storage && !mat_traits_lhs::is_interleaved) - && reg_traits_rhs::is_mma_input>> - : public integral_constant + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutSoa() { - }; + // SOA is a strict register layout where contiguous elements + // capture contiguous BlockK elements and must be consistent. + if constexpr(traits::is_storage) + { + return (traits::is_col_ortho || traits::is_row_ortho) + && testRegisterLayoutIdentity(); + } + else + { + // Interleaved acc is not SOA + return traits::is_mma_input || (traits::is_mma_acc && !traits::is_interleaved); + } + } - template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t> - : public integral_constant + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutMmaInput() { - }; + // MMA inputs must be compatible with MMA size support + if constexpr(traits::is_storage) + { + return traits::is_soa_format && testSupportedMmaDim(); + } + else + { + return traits::is_mma_input && testSupportedMmaDim(); + } + } - // Interleaved register layout compatibility with MmaInput requires: - // 1. Inner matrix layout and mma input layout must have same: MmaDim - // 2. MmaDim must satisfy criterion in testSupportedMmaDim(). + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutMmaAcc() + { + // MMA acc must be compatible with MMA dim and MaxVW + if constexpr(traits::is_storage && !traits::is_interleaved) + { + return testRegisterLayoutSoa() + && testSupportedMmaDim() + && testSupportedMmaAccMaxVW(); + } + else + { + // Interleaved storage layouts and MmaInput are not compatible + // with MMA acc format + return traits::is_mma_acc && testSupportedMmaDim(); + } + } + + // As a predicate to is_layout_same or is_layout_orthogonal, their register parameters must + // be compatible (see above table). template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t<(reg_traits_lhs::is_storage && mat_traits_lhs::is_interleaved) - && reg_traits_rhs::is_mma_input>> - : public integral_constant + ROCWMMA_HOST_DEVICE constexpr static bool testCompatibleRegisterParams() { - }; + // Basic test: + // Matching MmaDim, interleaving and validity + constexpr bool BaseTest = (traits_lhs::MmaDim == traits_rhs::MmaDim) + && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) + && (traits_lhs::is_valid == traits_rhs::is_valid); + + // Storage <-> Storage must check Matrix compatibility + if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) + { + return testCompatibleMatrixParams() + && BaseTest; + } + // MmaInput <-> MmaInput + // MmaAcc <-> MmaAcc + // Storage <-> MmaInput + else if constexpr((traits_lhs::is_mma_input && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc) + || (traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage)) + { + return BaseTest; + } + + // Storage <-> MmaAcc must also check MaxVW + else if constexpr((traits_lhs::is_storage && traits_rhs::is_mma_acc) + || (traits_lhs::is_mma_acc && traits_rhs::is_storage)) + { + using test_traits = conditional_t; + + constexpr uint32_t ExpectedAccMaxVW + = ((bool)ROCWMMA_ARCH_GFX12) ? 8u + : ((bool)ROCWMMA_ARCH_GFX11 + || is_same::value) + ? 1u + : 4u; + + constexpr bool TestMmaAccMaxVW = (ExpectedAccMaxVW == test_traits::MaxVectorWidth); + + return TestMmaAccMaxVW && BaseTest; + } + // MmaInput <-> MmaAcc not compatible + else + { + return false; + } + } template - struct is_compatible_register_params< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t> - : public integral_constant + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutSame() { - }; + // Required compatibility + constexpr bool TestCompatibleParams + = testCompatibleRegisterParams(); + + // Test both register layouts in same format + constexpr bool TestFormatMatch = (traits_lhs::Format == traits_rhs::Format); + + if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) + { + // Exact match for same matrix and data layouts + constexpr bool TestExactMatch + = testMatrixLayoutSame() + && testDataLayoutSame(); + + // Orthogonal matrix layout and orthogonal data layout (implicit transpose) + constexpr bool TestImplicitTranspose + = testMatrixLayoutOrthogonal() + && testDataLayoutOrthogonal(); + + // Special case: interleaved VW dimension + // Check matching dims and if either one is == 1u + if constexpr(traits_lhs::is_interleaved && traits_rhs::is_interleaved) + { + constexpr bool TestIdentityQuirks + = (traits_lhs::DimPerThread == traits_rhs::DimPerThread) + && (traits_lhs::KPerThread == traits_rhs::KPerThread) + && ((traits_lhs::DimPerThread == 1u) || (traits_lhs::KPerThread == 1u)); + + return (TestExactMatch || TestImplicitTranspose || TestFormatMatch + || TestIdentityQuirks) + && TestCompatibleParams; + } + + return (TestExactMatch || TestImplicitTranspose || TestFormatMatch) + && TestCompatibleParams; + } + else // Mix of storage, MmaInput, MmaAcc + { + // Test both sides for MmaInput compatibility + constexpr bool TestMmaInputMatch + = testRegisterLayoutMmaInput() + && testRegisterLayoutMmaInput() && TestCompatibleParams; + + // Test both sides for MmaAcc compatibility + constexpr bool TestMmaAccMatch = testRegisterLayoutMmaAcc() + && testRegisterLayoutMmaAcc() + && TestCompatibleParams; + + return (TestMmaInputMatch || TestMmaAccMatch || TestFormatMatch) + && TestCompatibleParams; + } + } - // Convenience evaluator template - constexpr static inline bool is_compatible_register_params_v - = is_compatible_register_params::value; + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutOrthogonal() + { + // Required not same and compatible params + constexpr bool TestNotSame + = !testRegisterLayoutSame(); + constexpr bool TestCompatibleParams + = testCompatibleRegisterParams(); + + // Path between valid AOS and SOA formats + constexpr bool TestOpposingFormat + = (traits_lhs::is_soa_format && traits_rhs::is_aos_format) + || (traits_lhs::is_aos_format && traits_rhs::is_soa_format); + + // (testRegisterLayoutAos() && testRegisterLayoutSoa()) + // || (testRegisterLayoutSoa() && testRegisterLayoutAos()); + + if constexpr((traits_lhs::is_interleaved && traits_rhs::is_interleaved) + && (traits_lhs::is_mma_acc || traits_rhs::is_mma_acc)) + { + using RegisterLayoutMmaAcc + = conditional_t; + using RegisterLayoutOther + = conditional_t; + + // Special case: path between valid interleaved AOS/SOA and MmaAcc register layouts exists. + constexpr bool TestStorageToAcc + = testRegisterLayoutMmaAcc() + && (testRegisterLayoutAos() + || testRegisterLayoutSoa()); + + return (TestOpposingFormat || TestStorageToAcc) && TestNotSame + && TestCompatibleParams; + } + else + { + return TestOpposingFormat && TestNotSame && TestCompatibleParams; + } + } // Checks if both RegisterLayout storages are the same with compatible params template struct is_layout_same< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant< bool, - // Check for same in-register layouts - ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_inline) - - // Check for in-register implicit transposes. These have the same register layouts, - // but swap meaning for rows / cols. - || (mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_inline) - - // Check mma input sameness - || (reg_traits_lhs::is_mma_input && reg_traits_rhs::is_mma_input) - || (mat_traits_lhs::is_col_ortho && reg_traits_rhs::is_mma_input) - || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_row_ortho && reg_traits_rhs::is_mma_input) - || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_row_ortho)) - && is_compatible_register_params_v> + testRegisterLayoutSame()> { }; @@ -351,44 +642,64 @@ namespace rocwmma struct is_layout_orthogonal< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t> + enable_if_t> : public integral_constant< bool, - // Orthogonality in same orientation (e.g., col / row) - ((mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_col_inline) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_row_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_row_ortho) - - // Orthogonality in opposite orientation (e.g., col vs row) - || (mat_traits_lhs::is_col_ortho && mat_traits_rhs::is_row_inline) - || (mat_traits_lhs::is_row_inline && mat_traits_rhs::is_col_ortho) - || (mat_traits_lhs::is_col_inline && mat_traits_rhs::is_row_ortho) - || (mat_traits_lhs::is_row_ortho && mat_traits_rhs::is_col_inline) - - // Check mma input compatibility - || (mat_traits_lhs::is_col_inline && reg_traits_rhs::is_mma_input) - || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_col_inline) - || (mat_traits_lhs::is_row_inline && reg_traits_rhs::is_mma_input) - || (reg_traits_lhs::is_mma_input && mat_traits_rhs::is_row_inline)) - && is_compatible_register_params_v> + testRegisterLayoutOrthogonal()> { }; -#undef reg_traits_lhs -#undef reg_traits_rhs -#undef mat_traits_lhs -#undef mat_traits_rhs +#undef traits_lhs +#undef traits_rhs +#undef traits // Use generic MatrixLayout orthogonality rules to guide the register layout transpose suggestion - template - struct orthogonal_layout> + // TODO: fix + template + struct orthogonal_layout> + { + using type = Storage::type, + typename orthogonal_layout::type>; + }; + + template + struct layout_traits>> + : public register_layout_traits { - using type = Storage::type>; }; } // namespace LayoutTraits_impl } // namespace rocwmma +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::register_layout_traits const& traits) + { + using register_traits = decay_t; + + stream << "RegisterLayout Traits: " << RegisterLayout{} << std::endl; + stream << "is_register_layout: " << traits.is_register_layout << std::endl; + stream << "is_storage: " << traits.is_storage << std::endl; + stream << "is_mma_input: " << traits.is_mma_input << std::endl; + stream << "is_mma_acc: " << traits.is_mma_acc << std::endl; + stream << "is_interleaved: " << traits.is_interleaved << std::endl; + stream << "MmaDim: " << traits.MmaDim << std::endl; + stream << "is_aos_format: " << traits.is_aos_format << std::endl; + stream << "is_soa_format: " << traits.is_soa_format << std::endl; + stream << "is_valid: " << traits.is_valid << std::endl; + stream << "Format: " << traits.Format << std::endl; + + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + #endif // ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index d2beada3..d51f5076 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -65,61 +65,87 @@ namespace rocwmma } }; - // AOS -> SOA transform (non-interleaved) requirements: - // - Lhs is *Inline - // - layouts are not interleaved - // - layouts are orthogonal + // Non-interleaved orthogonal transforms + // template struct register_layout_transform< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t<(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) - && !mat_traits_lhs::is_interleaved + enable_if_t<(!mat_traits_lhs::is_interleaved || !mat_traits_rhs::is_interleaved) && is_layout_orthogonal_v>> { template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) { - return Transforms::AosToSoa::exec(forward(v)); - } - }; + // Orthogonality promises: + // BlockDim, KDim, MaxVW match on lhs and rhs - // SOA -> AOS transform (non-interleaved) requirements: - // - Lhs is *Ortho - // - layouts are not interleaved - // - layouts are orthogonal - template - struct register_layout_transform< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t<(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) - && !mat_traits_lhs::is_interleaved - && is_layout_orthogonal_v>> - { - template - ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) - { - return Transforms::SoaToAos::exec(forward(v)); + // Inline to ortho layout (AOS -> SOA) + if constexpr(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) + { + return Transforms:: + AosToSoa::exec( + forward(v)); + } + // Ortho to inline layout (SOA -> AOS) + else if constexpr(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) + { + return Transforms:: + SoaToAos::exec( + forward(v)); + } + // MmaInput (ortho) to inline layout (SOA -> AOS) + else if constexpr(reg_traits_lhs::is_mma_input) + { + return Transforms:: + SoaToAos::exec( + forward(v)); + } + else + { + static_assert(0, "Shouldn't get here"); + return v; + } } }; - // Interleaved layout transform: - // - layouts are interleaved - // - layouts are orthogonal + // Interleaved orthogonal transforms template struct register_layout_transform< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t>> { template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) { - // TODO: replace with DimPerThread for interleaved. - return interleave(forward(v)); + // Orthogonality promises: + // BlockDim, KDim, MmaDim match on lhs and rhs + + // Inline to ortho layout (AOS -> SOA) + if constexpr(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) + { + // Leading dim VW + return interleave<1u, mat_traits_lhs::DimPerThread>(forward(v)); + } + // Ortho to inline layout (SOA -> AOS) + else if constexpr(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) + { + // KDim VW + return interleave<1u, mat_traits_lhs::KPerThread>(forward(v)); + } + // MmaInput (ortho) to inline + else if constexpr(reg_traits_lhs::is_mma_input) + { + // Leading dim VW + return interleave<1u, mat_traits_rhs::DimPerThread>(forward(v)); + } + else + { + static_assert(0, "Shouldn't get here"); + return v; + } } }; From 64da22160e09f1eb23ab9ad8f4031fa47728e664 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 11 Nov 2024 03:58:27 +0000 Subject: [PATCH 10/36] Introduce register formats and refactor is_layout_same and is_layout_orthogonal logic --- .../rocwmma/internal/layout/layout.hpp | 16 +- .../layout/register_layout_traits_impl.hpp | 517 ++++++------------ 2 files changed, 179 insertions(+), 354 deletions(-) diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index d85b0c77..83f07875 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -147,9 +147,11 @@ namespace rocwmma // Format for data locality enum struct Format : uint32_t { - SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] - AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] - None = 2u, + SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] + AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] + ACC_INT_A_MAJOR = 2u, // Interleaved Mma 'A' major order + ACC_INT_B_MAJOR = 3u, // Interleaved Mma 'B' major order + Invalid, // Invalid register format }; // A mnemonic used to describe the register layout is suitable for input/output @@ -167,7 +169,7 @@ namespace rocwmma // A mnemonic used to describe the register layout is suitable for mma input for accumulator input/output template + Format Fmt = Interleaved ? Format::ACC_INT_A_MAJOR : Format::SOA> struct MmaAcc { }; @@ -211,7 +213,11 @@ namespace std { return stream << (fmt == rocwmma::RegisterLayout::Format::AOS ? "AOS" : (fmt == rocwmma::RegisterLayout::Format::SOA) ? "SOA" - : "NONE"); + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_A_MAJOR) + ? "ACC_INT_A_MAJOR" + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_B_MAJOR) + ? "ACC_INT_B_MAJOR" + : "INVALID"); } template diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 7411f6ce..6b066e40 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -121,7 +121,6 @@ namespace rocwmma template struct register_layout_traits; - // Test the consistency of matrix layouts under different data layouts. // RegisterLayouts are consistent for both data layouts if we restrict // VectorWidth to 1 in the opposite data layout grain. // This applies to all matrix layouts. @@ -129,7 +128,6 @@ namespace rocwmma ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutIdentity() { using traits = register_layout_traits; - if constexpr(traits::is_col_inline) { return (traits::is_col_major || traits::VectorWidth == 1); @@ -150,36 +148,66 @@ namespace rocwmma return false; } + // AOS is a strict register layout where thread VW is inline + // with contiguous BlockDim elements. + // To be valid, the layout must be consistent across row_major + // and col_major data layouts. template ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutAos() { using traits = register_layout_traits; - - // AOS is a strict register layout where contiguous elements - // capture contiguous BlockDim elements and must be consistent. - return (traits::is_col_inline || traits::is_row_inline) - && testStorageLayoutIdentity(); + return (traits::is_col_inline || traits::is_row_inline); } + // SOA is a strict register layout where thread VW is inline + // with contiguous BlockK elements, orthogonal to BlockDim. + // To be valid, the layout must be consistent across row_major + // and col_major data layouts. template ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutSoa() { using traits = register_layout_traits; - - // SOA is a strict register layout where contiguous elements - // capture contiguous BlockK elements and must be consistent. - return (traits::is_col_ortho || traits::is_row_ortho) - && testStorageLayoutIdentity(); + return (traits::is_col_ortho || traits::is_row_ortho); } - // Based on the current config, mma dimensions supported + // Based on the current architecture, which mma dimensions supported template ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaDim() { using traits = register_layout_traits; return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && traits::MmaDim == 16u) - || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED - && (traits::MmaDim == 16u || traits::MmaDim == 32u)); + || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && traits::MmaDim == 32u); + } + + // Based on the current architecture, which register layout formats currently supported + template + ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedFormat() + { + using traits = register_layout_traits; + using rocwmma::RegisterLayout::Format; + if constexpr(traits::is_mma_input) + { + return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + } + else if constexpr(traits::is_mma_acc) + { + if constexpr(traits::is_interleaved) + { + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::ACC_INT_A_MAJOR) + || (traits::Format == Format::ACC_INT_B_MAJOR); + } + else + { + return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + } + } + else + { + return traits::is_storage + && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) + || (traits::Format == Format::Invalid)); + } } template @@ -195,17 +223,17 @@ namespace rocwmma using MatrixLayout = MatrixLayoutInternal; using DataLayout = DataLayoutInternal; - constexpr static bool is_aos_format - = testStorageLayoutAos>(); - constexpr static bool is_soa_format - = testStorageLayoutSoa>(); - constexpr static bool is_valid - = testStorageLayoutIdentity>(); - + // Determine the register format of the current storage layout constexpr static RegisterLayout::Format Format - = is_aos_format ? RegisterLayout::Format::AOS - : (is_soa_format ? RegisterLayout::Format::SOA - : RegisterLayout::Format::None); + = testStorageLayoutAos>() + ? RegisterLayout::Format::AOS + : (testStorageLayoutSoa>() + ? RegisterLayout::Format::SOA + : RegisterLayout::Format::Invalid); + + constexpr static bool is_valid + = testStorageLayoutIdentity>() + && testSupportedFormat>(); }; template @@ -219,12 +247,12 @@ namespace rocwmma constexpr static bool is_interleaved = LayoutIsInterleaved; constexpr static uint32_t MmaDim = LayoutMmaDim; - constexpr static bool is_aos_format = (Fmt == RegisterLayout::Format::AOS); - constexpr static bool is_soa_format = (Fmt == RegisterLayout::Format::SOA); - constexpr static bool is_valid - = testSupportedMmaDim>(); - + // Template param driven format constexpr static RegisterLayout::Format Format = Fmt; + + constexpr static bool is_valid + = testSupportedMmaDim>() + && testSupportedFormat>(); }; template @@ -238,285 +266,108 @@ namespace rocwmma constexpr static bool is_interleaved = LayoutIsInterleaved; constexpr static uint32_t MmaDim = LayoutMmaDim; - constexpr static bool is_aos_format = (Fmt == RegisterLayout::Format::AOS); - constexpr static bool is_soa_format = (Fmt == RegisterLayout::Format::SOA); - constexpr static bool is_valid - = testSupportedMmaDim>(); - + // Template param driven format constexpr static RegisterLayout::Format Format = Fmt; + + constexpr static bool is_valid + = testSupportedMmaDim>() + && testSupportedFormat>(); }; // Combine base instance traits with specific layout classifiers template - struct register_layout_traits : public register_layout_derived_traits, - public register_layout_classifier_traits + struct register_layout_traits : public register_layout_classifier_traits, + public register_layout_derived_traits + { }; - // NOTE: RegisterLayout assumptions + // NOTE: RegisterLayout comparison assumptions // When determining RegisterLayout traits, there are several strong assumptions. + // Register layouts are assigned Formats, based on their given matrix and data layouts. // 1. Regarding same-ness: - // - Storage match if MatrixLayouts match, given fixed params. - // - Storage match if MatrixLayouts are either both *Ortho or both *Inline - // orientations. Register thread mapping is the same while swapping the underlying - // meaning of rows for cols (e.g., implicit transpose). - // - Storage<*Ortho> layouts are suitable MmaInputs while Storage<*Inline> layouts are not. - // Given appropriate MmaDim, it is assumed MmaInput layouts are mapped to mma hardware - // requirements. - // _________________________________________________________________________________ - // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | - // | | (Same) | Required Fixed Params | - // | ------------------------------------------------------------------------------- | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | MmaInput | BlockDim == MmaDim | - // | MmaInput | Storage | BlockDim == MmaDim | - // | Storage | MmaInput | BlockDim == MmaDim | - // | MmaInput | Storage | BlockDim == MmaDim | - // | Storage | MmaAcc | BlockDim == MmaDim, MaxVW = 4* | - // | MmaAcc | Storage | BlockDim == MmaDim, MaxVW = 4* | - // | Storage | MmaAcc | BlockDim == MmaDim, MaxVW = 4* | - // | MmaAcc | Storage | BlockDim == MmaDim, MaxVW = 4* | * = arch dependent - // | ------------------------------------------------------------------------------- | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | MmaInput | MmaDim | - // | MmaInput | Storage | MmaDim | - // | Storage | MmaInput | MmaDim | - // | MmaInput | Storage | MmaDim | - // | ------------------------------------------------------------------------------- | + // - Register formats match, if tested for matching register layout traits: + // MmaDim, is_interleaved and is_valid. + // - Register layouts match if register formats match, and there is congruency between + // Storage, MmaInput and MmaAcc types. + // - Congruency between Storage, MmaInput and MmaAcc types is partly defined by how + // MmaInput and MmaAcc register format template parameters are set for the Mma workflow, + // and partly by architecture (e.g., MmaAcc layout VW per block is fixed). // // 2. Regarding orthogonality: - // - Storages are considered orthogonal if one MatrixLayout is an - // *Ortho layout and the other is an *Inline layout, or vice versa. - // - Since MmaInput layouts are same as Storage layouts with appropriate - // MmaDim, MmaInput is also orthogonal to Storage layouts. - // _______________________________________________________________________________ - // | MatrixLayoutLhs | MatrixLayoutRhs | Required Fixed Params | - // | | (Orthogonal) | | - // | ----------------------------------------------------------------------------- | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | Storage | BlockDim, KDim, MaxVectorWidth | - // | Storage | MmaInput | BlockDim == MmaDim | - // | MmaInput | Storage | BlockDim == MmaDim | - // | Storage | MmaInput | BlockDim == MmaDim | - // | MmaInput | Storage | BlockDim == MmaDim | - // | Storage | MmaAcc | BlockDim == MmaDim | - // | MmaAcc | Storage | BlockDim == MmaDim | - // | Storage | MmaInput | BlockDim == MmaDim | - // | MmaInput | Storage | BlockDim == MmaDim | - // | ----------------------------------------------------------------------------- | - // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage| BlockDim, KDim, MmaDim, SplitK | - // | Storage | Storage | BlockDim, KDim, MmaDim, SplitK | - // | Storage | MmaInput | MmaDim | - // | MmaInput | Storage| MmaDim | - // | Storage | MmaInput | MmaDim | - // | MmaInput | Storage| MmaDim | - // | Storage | MmaInput | MmaDim | - // | MmaInput | Storage| MmaDim | - // | ----------------------------------------------------------------------------- | + // - Format orthogonality is defined as having an in-register transition from one distinct + // format to another. + // E.g,. AOS <-> SOA, SOA <-> ACC_INT_A_MAJOR, SOA <-> ACC_INT_B_MAJOR, + // AOS <-> ACC_INT_A_MAJOR or AOS <-> ACC_INT_B_MAJOR. + // These require matching MmaDim, is_interleaved and is_valid traits. // Keeps things a bit more tidy. Quick access to register layout traits. #define traits_lhs register_layout_traits #define traits_rhs register_layout_traits #define traits register_layout_traits - template - ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaAccMaxVW() - { - // Test the MaxVectorWidth of storage layouts for MMA requirements. - if constexpr(traits::is_storage) - { - // Interleaved storage layouts not compatible with MmaAcc - if constexpr(traits::is_interleaved) - { - return false; - } - else if constexpr((bool)ROCWMMA_ARCH_GFX12) - { - return traits::MaxVectorWidth == 8u; - } - else if constexpr((bool)ROCWMMA_ARCH_GFX11 - || is_same::value) - { - return traits::MaxVectorWidth == 1u; - } - else // General case - { - return traits::MaxVectorWidth == 4u; - } - } - - // Mma input not compatible with acc - return traits::is_mma_acc; - } - - // Test the consistency of matrix layouts under different data layouts. - template - ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutIdentity() - { - if constexpr(traits::is_storage) - { - // RegisterLayouts are consistent for both data layouts if we restrict - // VectorWidth to 1 in the opposite data layout grain. - if constexpr(traits::is_col_inline) - { - return (traits::is_col_major || traits::VectorWidth == 1); - } - else if constexpr(traits::is_row_inline) - { - return (traits::is_row_major || traits::VectorWidth == 1); - } - else if constexpr(traits::is_col_ortho) - { - return (traits::is_row_major || traits::VectorWidth == 1u); - } - else if constexpr(traits::is_row_ortho) - { - return (traits::is_col_major || traits::VectorWidth == 1u); - } - } - - // Mma input and acc are symbolic register layouts. - // Both are consistent in either row/col major data layouts. - return traits::is_mma_input || traits::is_mma_acc; - } - - template - ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutAos() - { - // AOS is a strict register layout where contiguous elements - // capture contiguous BlockDim elements and must be consistent. - if constexpr(traits::is_storage) - { - return (traits::is_col_inline || traits::is_row_inline) - && testRegisterLayoutIdentity(); - } - else - { - // None of the MMA inputs are AOS - return !traits::is_mma_input && !traits::is_mma_acc; - } - } - - template - ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutSoa() - { - // SOA is a strict register layout where contiguous elements - // capture contiguous BlockK elements and must be consistent. - if constexpr(traits::is_storage) - { - return (traits::is_col_ortho || traits::is_row_ortho) - && testRegisterLayoutIdentity(); - } - else - { - // Interleaved acc is not SOA - return traits::is_mma_input || (traits::is_mma_acc && !traits::is_interleaved); - } - } - - template - ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutMmaInput() - { - // MMA inputs must be compatible with MMA size support - if constexpr(traits::is_storage) - { - return traits::is_soa_format && testSupportedMmaDim(); - } - else - { - return traits::is_mma_input && testSupportedMmaDim(); - } - } - - template - ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutMmaAcc() - { - // MMA acc must be compatible with MMA dim and MaxVW - if constexpr(traits::is_storage && !traits::is_interleaved) - { - return testRegisterLayoutSoa() - && testSupportedMmaDim() - && testSupportedMmaAccMaxVW(); - } - else - { - // Interleaved storage layouts and MmaInput are not compatible - // with MMA acc format - return traits::is_mma_acc && testSupportedMmaDim(); - } - } - - // As a predicate to is_layout_same or is_layout_orthogonal, their register parameters must - // be compatible (see above table). + // As a predicate to is_layout_same or is_layout_orthogonal, their register traits must + // be compatible as per above. template ROCWMMA_HOST_DEVICE constexpr static bool testCompatibleRegisterParams() { // Basic test: // Matching MmaDim, interleaving and validity + // Note: matching validity does not imply valid! + // Cannot mix valid with invalid layouts constexpr bool BaseTest = (traits_lhs::MmaDim == traits_rhs::MmaDim) && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) && (traits_lhs::is_valid == traits_rhs::is_valid); - // Storage <-> Storage must check Matrix compatibility - if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) - { - return testCompatibleMatrixParams() - && BaseTest; - } // MmaInput <-> MmaInput // MmaAcc <-> MmaAcc // Storage <-> MmaInput - else if constexpr((traits_lhs::is_mma_input && traits_rhs::is_mma_input) - || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc) - || (traits_lhs::is_storage && traits_rhs::is_mma_input) - || (traits_lhs::is_mma_input && traits_rhs::is_storage)) + // MmaDim must match and be supported + if constexpr((traits_lhs::is_mma_input && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc) + || (traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage)) { - return BaseTest; + return BaseTest && testSupportedMmaDim(); } - - // Storage <-> MmaAcc must also check MaxVW + // Storage <-> MmaAcc + // MmaAcc must check MaxVW + // MmaDim must match and be supported else if constexpr((traits_lhs::is_storage && traits_rhs::is_mma_acc) || (traits_lhs::is_mma_acc && traits_rhs::is_storage)) { using test_traits = conditional_t; - constexpr uint32_t ExpectedAccMaxVW - = ((bool)ROCWMMA_ARCH_GFX12) ? 8u - : ((bool)ROCWMMA_ARCH_GFX11 - || is_same::value) - ? 1u - : 4u; - - constexpr bool TestMmaAccMaxVW = (ExpectedAccMaxVW == test_traits::MaxVectorWidth); - - return TestMmaAccMaxVW && BaseTest; + if constexpr(test_traits::is_interleaved) + { + return ((test_traits::Format == RegisterLayout::Format::ACC_INT_A_MAJOR) + || (test_traits::Format == RegisterLayout::Format::ACC_INT_B_MAJOR)) + && BaseTest && testSupportedMmaDim(); + } + else + { + // Acc layout architecture quirks + constexpr uint32_t ExpectedAccMaxVW + = ((bool)ROCWMMA_ARCH_GFX12) ? 8u + : ((bool)ROCWMMA_ARCH_GFX11 + || is_same::value) + ? 1u + : 4u; + + constexpr bool TestMmaAccMaxVW + = (ExpectedAccMaxVW == test_traits::MaxVectorWidth); + + return TestMmaAccMaxVW && BaseTest && testSupportedMmaDim(); + } + } + // Storage <-> Storage + // Must check Matrix compatibility + // Not necessary to check MmaDim because doesn't involve MmaInput of MmaAcc + else if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) + { + return testCompatibleMatrixParams() + && BaseTest; } // MmaInput <-> MmaAcc not compatible else @@ -532,56 +383,24 @@ namespace rocwmma constexpr bool TestCompatibleParams = testCompatibleRegisterParams(); - // Test both register layouts in same format - constexpr bool TestFormatMatch = (traits_lhs::Format == traits_rhs::Format); - - if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) + if constexpr((traits_lhs::is_storage && traits_rhs::is_storage) + && (traits_lhs::is_interleaved && traits_rhs::is_interleaved)) { - // Exact match for same matrix and data layouts - constexpr bool TestExactMatch - = testMatrixLayoutSame() - && testDataLayoutSame(); - - // Orthogonal matrix layout and orthogonal data layout (implicit transpose) - constexpr bool TestImplicitTranspose - = testMatrixLayoutOrthogonal() - && testDataLayoutOrthogonal(); - - // Special case: interleaved VW dimension - // Check matching dims and if either one is == 1u - if constexpr(traits_lhs::is_interleaved && traits_rhs::is_interleaved) - { - constexpr bool TestIdentityQuirks - = (traits_lhs::DimPerThread == traits_rhs::DimPerThread) - && (traits_lhs::KPerThread == traits_rhs::KPerThread) - && ((traits_lhs::DimPerThread == 1u) || (traits_lhs::KPerThread == 1u)); - - return (TestExactMatch || TestImplicitTranspose || TestFormatMatch - || TestIdentityQuirks) - && TestCompatibleParams; - } - - return (TestExactMatch || TestImplicitTranspose || TestFormatMatch) - && TestCompatibleParams; + // Special case: interleaved layouts + // Check matching thread dims and if either one is == 1u. + // Format match not required because the in this case, + // register contents for SOA and AOS are identical + constexpr bool TestIdentityQuirks + = (traits_lhs::DimPerThread == traits_rhs::DimPerThread) + && (traits_lhs::KPerThread == traits_rhs::KPerThread) + && ((traits_lhs::DimPerThread == 1u) || (traits_lhs::KPerThread == 1u)); + + return TestIdentityQuirks && TestCompatibleParams; } - else // Mix of storage, MmaInput, MmaAcc + else { - // Test both sides for MmaInput compatibility - constexpr bool TestMmaInputMatch - = testRegisterLayoutMmaInput() - && testRegisterLayoutMmaInput() && TestCompatibleParams; - - // Test both sides for MmaAcc compatibility - constexpr bool TestMmaAccMatch = testRegisterLayoutMmaAcc() - && testRegisterLayoutMmaAcc() - && TestCompatibleParams; - - return (TestMmaInputMatch || TestMmaAccMatch || TestFormatMatch) - && TestCompatibleParams; + // Test both register layouts in same format + return TestCompatibleParams && (traits_lhs::Format == traits_rhs::Format); } } @@ -591,38 +410,38 @@ namespace rocwmma // Required not same and compatible params constexpr bool TestNotSame = !testRegisterLayoutSame(); + constexpr bool TestCompatibleParams = testCompatibleRegisterParams(); - // Path between valid AOS and SOA formats + // Identify valid paths in orthogonality. + // SOA <-> AOS + // ACC_INT_A_MAJOR <-> AOS, SOA + // ACC_INT_B_MAJOR <-> AOS, SOA + // Register layouts must be valid to be orthogonal + using RegisterLayout::Format; constexpr bool TestOpposingFormat - = (traits_lhs::is_soa_format && traits_rhs::is_aos_format) - || (traits_lhs::is_aos_format && traits_rhs::is_soa_format); - - // (testRegisterLayoutAos() && testRegisterLayoutSoa()) - // || (testRegisterLayoutSoa() && testRegisterLayoutAos()); - - if constexpr((traits_lhs::is_interleaved && traits_rhs::is_interleaved) - && (traits_lhs::is_mma_acc || traits_rhs::is_mma_acc)) - { - using RegisterLayoutMmaAcc - = conditional_t; - using RegisterLayoutOther - = conditional_t; - - // Special case: path between valid interleaved AOS/SOA and MmaAcc register layouts exists. - constexpr bool TestStorageToAcc - = testRegisterLayoutMmaAcc() - && (testRegisterLayoutAos() - || testRegisterLayoutSoa()); - - return (TestOpposingFormat || TestStorageToAcc) && TestNotSame - && TestCompatibleParams; - } - else - { - return TestOpposingFormat && TestNotSame && TestCompatibleParams; - } + = ((traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR + && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR + && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::ACC_INT_B_MAJOR)) + && (traits_lhs::is_valid && traits_rhs::is_valid); + + return TestNotSame && TestCompatibleParams && TestOpposingFormat; } // Checks if both RegisterLayout storages are the same with compatible params @@ -690,8 +509,8 @@ namespace std stream << "is_mma_acc: " << traits.is_mma_acc << std::endl; stream << "is_interleaved: " << traits.is_interleaved << std::endl; stream << "MmaDim: " << traits.MmaDim << std::endl; - stream << "is_aos_format: " << traits.is_aos_format << std::endl; - stream << "is_soa_format: " << traits.is_soa_format << std::endl; + // stream << "is_aos_format: " << traits.is_aos_format << std::endl; + // stream << "is_soa_format: " << traits.is_soa_format << std::endl; stream << "is_valid: " << traits.is_valid << std::endl; stream << "Format: " << traits.Format << std::endl; From b222868b5a2cfaf82c43c650c19c191439a7de1e Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Fri, 15 Nov 2024 00:09:36 +0000 Subject: [PATCH 11/36] Fixup interleaved layouts logic bugs. Add layout formats to fit all workflows --- .../rocwmma/internal/layout/layout.hpp | 39 +++-- .../internal/layout/matrix_layout_impl.hpp | 8 +- .../layout/register_layout_traits_impl.hpp | 147 ++++++++++++------ .../layout/register_layout_transforms.hpp | 91 ++++------- 4 files changed, 159 insertions(+), 126 deletions(-) diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index 83f07875..1d31d48a 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -147,10 +147,16 @@ namespace rocwmma // Format for data locality enum struct Format : uint32_t { - SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] - AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] - ACC_INT_A_MAJOR = 2u, // Interleaved Mma 'A' major order - ACC_INT_B_MAJOR = 3u, // Interleaved Mma 'B' major order + SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] + AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] + SOA_INT = 2u, // SOA interleaved + AOS_INT = 3u, // AOS interleaved + ACC_INT_A_MAJOR = 4u, // Interleaved MmaAcc 'A' major order + ACC_INT_B_MAJOR = 5u, // Interleaved MmaAcc 'B' major order + WMMA_INPUT_GFX11 = 6u, // Gfx11 input format + WMMA_ACC_GFX11 = 7u, // Gfx11 acc format + WMMA_ACC_INT_A_MAJOR_GFX11 = 8u, // Gfx11 interleaved MmaAcc 'A' major order + WMMA_ACC_INT_B_MAJOR_GFX11 = 9u, // Gfx11 interleaved MmaAcc 'B' major order Invalid, // Invalid register format }; @@ -161,7 +167,9 @@ namespace rocwmma }; // A mnemonic used to describe the register layout is suitable for mma input for A/B - template + template struct MmaInput { }; @@ -211,13 +219,16 @@ namespace std { inline ostream& operator<<(ostream& stream, rocwmma::RegisterLayout::Format const& fmt) { - return stream << (fmt == rocwmma::RegisterLayout::Format::AOS ? "AOS" - : (fmt == rocwmma::RegisterLayout::Format::SOA) ? "SOA" - : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_A_MAJOR) - ? "ACC_INT_A_MAJOR" - : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_B_MAJOR) - ? "ACC_INT_B_MAJOR" - : "INVALID"); + return stream + << (fmt == rocwmma::RegisterLayout::Format::AOS ? "AOS" + : (fmt == rocwmma::RegisterLayout::Format::SOA) ? "SOA" + : (fmt == rocwmma::RegisterLayout::Format::AOS_INT) ? "AOS_INT" + : (fmt == rocwmma::RegisterLayout::Format::SOA_INT) ? "SOA_INT" + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_A_MAJOR) ? "ACC_INT_A_MAJOR" + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_B_MAJOR) ? "ACC_INT_B_MAJOR" + : (fmt == rocwmma::RegisterLayout::Format::WMMA_INPUT_GFX11) ? "WMMA_INPUT_GFX11" + : (fmt == rocwmma::RegisterLayout::Format::WMMA_ACC_GFX11) ? "WMMA_ACC_GFX11" + : "INVALID"); } template @@ -238,8 +249,8 @@ namespace std template inline ostream& - operator<<(ostream& stream, - rocwmma::RegisterLayout::MmaAcc const& register_layout) + operator<<(ostream& stream, + rocwmma::RegisterLayout::MmaAcc const& register_layout) { return stream << "MmaAcc<" << MmaDim << ", " << Interleaved << ", " << Fmt << ">"; } diff --git a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp index a14e2cfe..61db6612 100644 --- a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -1083,7 +1083,7 @@ namespace std rocwmma::MatrixLayout:: RowOrthoVW const& matrix_layout) { - return stream << "ColOrthoVW<" << BlockDim << ", " << BlockK << ", " + return stream << "RowOrthoVW<" << BlockDim << ", " << BlockK << ", " << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " << MaxVectorWidth << ">"; } @@ -1098,7 +1098,7 @@ namespace std rocwmma::MatrixLayout:: RowInlineVW const& matrix_layout) { - return stream << "ColInlineVW<" << BlockDim << ", " << BlockK << ", " + return stream << "RowInlineVW<" << BlockDim << ", " << BlockK << ", " << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " << MaxVectorWidth << ">"; } @@ -1131,7 +1131,7 @@ namespace std rocwmma::MatrixLayout::RowOrthoInt const& matrix_layout) { - return stream << "ColOrthoInt<" << BlockDim << ", " << BlockK << ", " + return stream << "RowOrthoInt<" << BlockDim << ", " << BlockK << ", " << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK << ">"; } @@ -1142,7 +1142,7 @@ namespace std rocwmma::MatrixLayout::RowInlineInt const& matrix_layout) { - return stream << "ColInlineInt<" << BlockDim << ", " << BlockK << ", " + return stream << "RowInlineInt<" << BlockDim << ", " << BlockK << ", " << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK << ">"; } diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 6b066e40..def554c8 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -187,10 +187,30 @@ namespace rocwmma using rocwmma::RegisterLayout::Format; if constexpr(traits::is_mma_input) { - return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + if constexpr(traits::is_interleaved) + { + return (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT); + } + else + { + return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + } } else if constexpr(traits::is_mma_acc) { +#if ROCWMMA_ARCH_GFX11 + if constexpr(traits::is_interleaved) + { + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) + || (traits::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11); + } + else + { + return (traits::Format == WMMA_ACC_GFX11); + } +#else if constexpr(traits::is_interleaved) { // Intermediate accumulation format for interleaved layout @@ -201,15 +221,54 @@ namespace rocwmma { return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); } +#endif // ROCWMMA_ARCH_GFX11 } else { return traits::is_storage && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) + || (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT) || (traits::Format == Format::Invalid)); } } + template + ROCWMMA_HOST_DEVICE constexpr static inline auto registerFormat() + { + using traits = register_layout_traits; + using rocwmma::RegisterLayout::Format; + + // MmaInput and MmaAcc are statically assigned + if constexpr(traits::is_mma_input || traits::is_mma_acc) + { + return traits::Format; + } + // Determine the register format of the current storage layout + // based on the layout traits. + else if constexpr(traits::is_storage) + { + if constexpr(traits::is_interleaved) + { + return testStorageLayoutAos() + ? Format::AOS_INT + : (testStorageLayoutSoa() ? Format::SOA_INT + : Format::Invalid); + } + else + { + return testStorageLayoutAos() + ? Format::AOS + : (testStorageLayoutSoa() ? Format::SOA + : Format::Invalid); + } + } + else + { + return Format::Invalid; + } + } + template struct register_layout_derived_traits { @@ -223,13 +282,8 @@ namespace rocwmma using MatrixLayout = MatrixLayoutInternal; using DataLayout = DataLayoutInternal; - // Determine the register format of the current storage layout constexpr static RegisterLayout::Format Format - = testStorageLayoutAos>() - ? RegisterLayout::Format::AOS - : (testStorageLayoutSoa>() - ? RegisterLayout::Format::SOA - : RegisterLayout::Format::Invalid); + = registerFormat>(); constexpr static bool is_valid = testStorageLayoutIdentity>() @@ -321,28 +375,27 @@ namespace rocwmma // MmaInput <-> MmaInput // MmaAcc <-> MmaAcc - // Storage <-> MmaInput - // MmaDim must match and be supported if constexpr((traits_lhs::is_mma_input && traits_rhs::is_mma_input) - || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc) - || (traits_lhs::is_storage && traits_rhs::is_mma_input) - || (traits_lhs::is_mma_input && traits_rhs::is_storage)) + || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc)) { - return BaseTest && testSupportedMmaDim(); + return BaseTest; } // Storage <-> MmaAcc - // MmaAcc must check MaxVW - // MmaDim must match and be supported - else if constexpr((traits_lhs::is_storage && traits_rhs::is_mma_acc) + // Storage <-> MmaInput + // Storage must be valid layout + // Non-interleaved MmaAcc must check MaxVW + else if constexpr((traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage) + || (traits_lhs::is_storage && traits_rhs::is_mma_acc) || (traits_lhs::is_mma_acc && traits_rhs::is_storage)) { - using test_traits = conditional_t; + using storage_traits + = conditional_t; + using mma_traits = conditional_t; - if constexpr(test_traits::is_interleaved) + if constexpr(mma_traits::is_mma_input || mma_traits::is_interleaved) { - return ((test_traits::Format == RegisterLayout::Format::ACC_INT_A_MAJOR) - || (test_traits::Format == RegisterLayout::Format::ACC_INT_B_MAJOR)) - && BaseTest && testSupportedMmaDim(); + return BaseTest && storage_traits::is_valid; } else { @@ -350,19 +403,18 @@ namespace rocwmma constexpr uint32_t ExpectedAccMaxVW = ((bool)ROCWMMA_ARCH_GFX12) ? 8u : ((bool)ROCWMMA_ARCH_GFX11 - || is_same::value) + || is_same::value) ? 1u : 4u; constexpr bool TestMmaAccMaxVW - = (ExpectedAccMaxVW == test_traits::MaxVectorWidth); + = (ExpectedAccMaxVW == storage_traits::MaxVectorWidth); - return TestMmaAccMaxVW && BaseTest && testSupportedMmaDim(); + return TestMmaAccMaxVW && BaseTest && storage_traits::is_valid; } } // Storage <-> Storage // Must check Matrix compatibility - // Not necessary to check MmaDim because doesn't involve MmaInput of MmaAcc else if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) { return testCompatibleMatrixParams(); - if constexpr((traits_lhs::is_storage && traits_rhs::is_storage) - && (traits_lhs::is_interleaved && traits_rhs::is_interleaved)) + // General case the formats match + constexpr bool TestFormatMatch = (traits_lhs::Format == traits_rhs::Format); + + if constexpr((traits_lhs::is_interleaved && traits_rhs::is_interleaved) + && ((traits_lhs::is_storage && traits_rhs::is_storage) + || (traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage))) { + using storage_traits + = conditional_t; + // Special case: interleaved layouts // Check matching thread dims and if either one is == 1u. - // Format match not required because the in this case, - // register contents for SOA and AOS are identical + // Register contents will be identical, regardless if the format matches. constexpr bool TestIdentityQuirks - = (traits_lhs::DimPerThread == traits_rhs::DimPerThread) - && (traits_lhs::KPerThread == traits_rhs::KPerThread) - && ((traits_lhs::DimPerThread == 1u) || (traits_lhs::KPerThread == 1u)); + = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); - return TestIdentityQuirks && TestCompatibleParams; + return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); } else { // Test both register layouts in same format - return TestCompatibleParams && (traits_lhs::Format == traits_rhs::Format); + return TestCompatibleParams && TestFormatMatch; } } @@ -423,21 +480,25 @@ namespace rocwmma constexpr bool TestOpposingFormat = ((traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::SOA_INT) || (traits_lhs::Format == Format::ACC_INT_A_MAJOR - && traits_rhs::Format == Format::SOA) + && traits_rhs::Format == Format::SOA_INT) || (traits_lhs::Format == Format::ACC_INT_A_MAJOR - && traits_rhs::Format == Format::AOS) - || (traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) - || (traits_lhs::Format == Format::AOS + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) || (traits_lhs::Format == Format::ACC_INT_B_MAJOR - && traits_rhs::Format == Format::SOA) + && traits_rhs::Format == Format::SOA_INT) || (traits_lhs::Format == Format::ACC_INT_B_MAJOR - && traits_rhs::Format == Format::AOS) - || (traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) - || (traits_lhs::Format == Format::AOS + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR)) && (traits_lhs::is_valid && traits_rhs::is_valid); @@ -509,8 +570,6 @@ namespace std stream << "is_mma_acc: " << traits.is_mma_acc << std::endl; stream << "is_interleaved: " << traits.is_interleaved << std::endl; stream << "MmaDim: " << traits.MmaDim << std::endl; - // stream << "is_aos_format: " << traits.is_aos_format << std::endl; - // stream << "is_soa_format: " << traits.is_soa_format << std::endl; stream << "is_valid: " << traits.is_valid << std::endl; stream << "Format: " << traits.Format << std::endl; diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index d51f5076..8c2735f4 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -37,12 +37,8 @@ namespace rocwmma using LayoutTraits_impl::register_layout_traits; // Keeps things a bit more tidy. Quick access to register layout traits. -#define reg_traits_lhs register_layout_traits -#define reg_traits_rhs register_layout_traits - -// Quick access to matrix layout traits, that are embedded in the register layout traits. -#define mat_traits_lhs matrix_layout_traits -#define mat_traits_rhs matrix_layout_traits +#define traits_lhs register_layout_traits +#define traits_rhs register_layout_traits // Note: If you arrive at an undefined register_transform error, it is likely // the layout transformation is not currently supported. Need to either implement @@ -65,81 +61,50 @@ namespace rocwmma } }; - // Non-interleaved orthogonal transforms - // + // Apply paths between orthogonal transforms template struct register_layout_transform< RegisterLayoutLhs, RegisterLayoutRhs, - enable_if_t<(!mat_traits_lhs::is_interleaved || !mat_traits_rhs::is_interleaved) + enable_if_t<(traits_lhs::is_register_layout && traits_rhs::is_register_layout) && is_layout_orthogonal_v>> { template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) { - // Orthogonality promises: - // BlockDim, KDim, MaxVW match on lhs and rhs + using RegisterLayout::Format; - // Inline to ortho layout (AOS -> SOA) - if constexpr(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) + // Non-interleaved AOS to SOA + if constexpr(traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) { + using storage_traits + = conditional_t; return Transforms:: - AosToSoa::exec( + AosToSoa::exec( forward(v)); } - // Ortho to inline layout (SOA -> AOS) - else if constexpr(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) + else if constexpr(traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::AOS) { + using storage_traits + = conditional_t; return Transforms:: - SoaToAos::exec( + SoaToAos::exec( forward(v)); } - // MmaInput (ortho) to inline layout (SOA -> AOS) - else if constexpr(reg_traits_lhs::is_mma_input) - { - return Transforms:: - SoaToAos::exec( - forward(v)); - } - else - { - static_assert(0, "Shouldn't get here"); - return v; - } - } - }; - - // Interleaved orthogonal transforms - template - struct register_layout_transform< - RegisterLayoutLhs, - RegisterLayoutRhs, - enable_if_t<(mat_traits_lhs::is_interleaved || mat_traits_rhs::is_interleaved) - && is_layout_orthogonal_v>> - { - template - ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) - { - // Orthogonality promises: - // BlockDim, KDim, MmaDim match on lhs and rhs - - // Inline to ortho layout (AOS -> SOA) - if constexpr(mat_traits_lhs::is_col_inline || mat_traits_lhs::is_row_inline) - { - // Leading dim VW - return interleave<1u, mat_traits_lhs::DimPerThread>(forward(v)); - } - // Ortho to inline layout (SOA -> AOS) - else if constexpr(mat_traits_lhs::is_col_ortho || mat_traits_lhs::is_row_ortho) + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::SOA_INT) { - // KDim VW - return interleave<1u, mat_traits_lhs::KPerThread>(forward(v)); + using storage_traits + = conditional_t; + return interleave<1u, storage_traits::DimPerThread>(forward(v)); } - // MmaInput (ortho) to inline - else if constexpr(reg_traits_lhs::is_mma_input) + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::AOS_INT) { - // Leading dim VW - return interleave<1u, mat_traits_rhs::DimPerThread>(forward(v)); + using storage_traits + = conditional_t; + return interleave<1u, storage_traits::KPerThread>(forward(v)); } else { @@ -149,10 +114,8 @@ namespace rocwmma } }; -#undef reg_traits_lhs -#undef reg_traits_rhs -#undef mat_traits_lhs -#undef mat_traits_rhs +#undef traits_lhs +#undef traits_rhs } // namespace RegisterTransform_impl From a5c23e4547b16a3482210486ea17eb7a5d009648 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Fri, 15 Nov 2024 18:35:19 +0000 Subject: [PATCH 12/36] Fix compiler unroll issue with function arg --- library/include/rocwmma/internal/opaque_load.hpp | 11 ++++++++--- library/include/rocwmma/internal/opaque_store.hpp | 15 ++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/library/include/rocwmma/internal/opaque_load.hpp b/library/include/rocwmma/internal/opaque_load.hpp index 06fdc066..94d63302 100644 --- a/library/include/rocwmma/internal/opaque_load.hpp +++ b/library/include/rocwmma/internal/opaque_load.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -144,12 +144,17 @@ namespace rocwmma MatrixLayout::strideCounts()), "IOCount inconsistent with total strides"); + // Initialize the stride details as constexpr + // so that the compiler can optimize them as args. + constexpr auto strideCounts = MatrixLayout::strideCounts(); + constexpr auto strides = MatrixLayout::strides(); + // Unroll loading in each strided dimension unroll_right(it, dataPtr + DataLayout::fromMatrixCoord(baseOffset2d, ldm), ldm, - MatrixLayout::strideCounts(), - MatrixLayout::strides()); + strideCounts, + strides); } }; diff --git a/library/include/rocwmma/internal/opaque_store.hpp b/library/include/rocwmma/internal/opaque_store.hpp index 7afbd2e7..2880806a 100644 --- a/library/include/rocwmma/internal/opaque_store.hpp +++ b/library/include/rocwmma/internal/opaque_store.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -86,7 +86,7 @@ namespace rocwmma if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll - for(int i = 0; i < strideCount; i++) + for(unsigned int i = 0; i < strideCount; i++) { Traits::Storer::exec(dataPtr, *in); dataPtr += strideOffset; @@ -97,7 +97,7 @@ namespace rocwmma else { #pragma unroll - for(int i = 0; i < strideCount; i++) + for(unsigned int i = 0; i < strideCount; i++) { unroll_right(dataPtr, in, ldm, strideCounts, strides2d); dataPtr += strideOffset; @@ -121,11 +121,16 @@ namespace rocwmma MatrixLayout::strideCounts()), "IOCount inconsistent with total strides"); + // Initialize the stride details as constexpr + // so that the compiler can optimize them as args. + constexpr auto strideCounts = MatrixLayout::strideCounts(); + constexpr auto strides = MatrixLayout::strides(); + unroll_right(dataPtr + DataLayout::fromMatrixCoord(baseOffset2d, ldm), it, ldm, - MatrixLayout::strideCounts(), - MatrixLayout::strides()); + strideCounts, + strides); } }; From 678a2d34ece2d3733eb8f7c4ab574ff96b4c76d1 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Sat, 16 Nov 2024 00:37:59 +0000 Subject: [PATCH 13/36] Deploy new mma workflow --- .../rocwmma/internal/coop_io_config.hpp | 12 ++- .../include/rocwmma/internal/io_config.hpp | 22 ++++- .../include/rocwmma/internal/io_layout.hpp | 96 ++++++++++++++----- .../layout/register_layout_transforms.hpp | 1 + .../rocwmma/internal/transforms_impl.hpp | 3 +- library/include/rocwmma/rocwmma_coop_impl.hpp | 37 ++++--- library/include/rocwmma/rocwmma_impl.hpp | 52 ++++++---- .../rocwmma/rocwmma_transforms_impl.hpp | 95 ++++-------------- 8 files changed, 183 insertions(+), 135 deletions(-) diff --git a/library/include/rocwmma/internal/coop_io_config.hpp b/library/include/rocwmma/internal/coop_io_config.hpp index 9b39fded..9d1a1767 100644 --- a/library/include/rocwmma/internal/coop_io_config.hpp +++ b/library/include/rocwmma/internal/coop_io_config.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #ifndef ROCWMMA_COOP_IO_CONFIG_HPP #define ROCWMMA_COOP_IO_CONFIG_HPP +#include "./layout/register_layout_transforms.hpp" #include "coop_load.hpp" #include "coop_store.hpp" #include "io_layout.hpp" @@ -85,6 +86,15 @@ namespace rocwmma typename IOLayout::MatrixLayout, IOLayout::VW>; + using PostLoadXForm = register_layout_transform; + + using PreMmaXForm = register_layout_transform; + + using PreStoreXForm = register_layout_transform; + using Storer = CooperativeStore; + using PostLoadXForm = register_layout_transform; + + using PreMmaXForm = register_layout_transform; + + using PreStoreXForm = register_layout_transform; + using Storer = OpaqueStore struct IOConfig { - using IOShape = IOShape; - using IOTraits = IOTraits; - using PackUtil = PackUtil; + using IOShape = IOShape; + using IOLayout = IOLayout; + using IOTraits = IOTraits; + using PackUtil = PackUtil; using Broadcaster = Broadcast; + + using PreMmaXForm = register_layout_transform; }; /** @}*/ diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index ca88bcc8..6c3b5507 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -138,15 +138,32 @@ namespace rocwmma VW = is_same::value || BlockDim > 32 ? MaxVW : 1u }; + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts // Layout profile for 'matrix_a': ColNT for small frags, Col for large frags - using Profile = conditional_t< - BlockDim <= 32, - LayoutProfile::template ColNT, - LayoutProfile::template Col>; - - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + using SmallDimMatrixLayout + = conditional_t, + MatrixLayout::ColOrthoVW, + MatrixLayout::ColOrthoVW>; + + using LargeDimMatrixLayout + = conditional_t, + MatrixLayout::ColInlineVW, + MatrixLayout::ColOrthoVW>; + + using MatrixLayout + = conditional_t; + + // Register layouts + using MemoryLayout = RegisterLayout::Storage; + using FragmentLayout = MemoryLayout; + using MmaLayout = RegisterLayout::MmaInput; }; template ::value || BlockDim > 32 ? MaxVW : 1u }; - // Layout profile for 'matrix_b': RowNT for small frags, Row for large frags - using Profile = conditional_t< - BlockDim <= 32, - LayoutProfile::template RowNT, - LayoutProfile::template Row>; + // DataLayout + using DataLayout = DataLayout::template Array1d; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + // Matrix Layouts + // Layout profile for 'matrix_a': ColNT for small frags, Col for large frags + using SmallDimMatrixLayout + = conditional_t, + MatrixLayout::RowOrthoVW, + MatrixLayout::RowOrthoVW>; + + using LargeDimMatrixLayout + = conditional_t, + MatrixLayout::RowInlineVW, + MatrixLayout::RowOrthoVW>; + + using MatrixLayout + = conditional_t; + + // Register layouts + using MemoryLayout = RegisterLayout::Storage; + using FragmentLayout = MemoryLayout; + using MmaLayout = RegisterLayout::MmaInput; }; template ::value ? MaxVW : 1u }; - // Layout profile for 'accumulator' set to RowNT - using Profile - = LayoutProfile::template RowNT; + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Layout profile for 'accumulator' set to RowNT, small frags + using MatrixLayout + = conditional_t, + MatrixLayout::RowOrthoVW, + MatrixLayout::RowOrthoVW>; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + // Register layouts + using MemoryLayout = RegisterLayout::Storage; + using MmaLayout = RegisterLayout::MmaAcc; + using FragmentLayout = MemoryLayout; }; template struct IOLayout { - // No layout mapping without VW, MaxVW and DataLayoutT info + using MemoryLayout = void; + using MmaLayout = RegisterLayout::MmaAcc; + using FragmentLayout = MmaLayout; }; } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 8c2735f4..ed9abfa6 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -26,6 +26,7 @@ #ifndef ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP #define ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP +#include "../transforms.hpp" #include "layout.hpp" #include "layout_traits.hpp" diff --git a/library/include/rocwmma/internal/transforms_impl.hpp b/library/include/rocwmma/internal/transforms_impl.hpp index 97981e2e..10959b42 100644 --- a/library/include/rocwmma/internal/transforms_impl.hpp +++ b/library/include/rocwmma/internal/transforms_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ #include "io_traits.hpp" #include "pack_util.hpp" #include "permute.hpp" +#include "swizzle.hpp" #include "utils.hpp" #include "vector_util.hpp" diff --git a/library/include/rocwmma/rocwmma_coop_impl.hpp b/library/include/rocwmma/rocwmma_coop_impl.hpp index de563205..09a05383 100644 --- a/library/include/rocwmma/rocwmma_coop_impl.hpp +++ b/library/include/rocwmma/rocwmma_coop_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -66,8 +66,10 @@ namespace rocwmma uint32_t waveCount) { - using FragT = decay_t; - using Loader = typename GetCoopIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -82,6 +84,9 @@ namespace rocwmma // Note: the frag will only be partially filled with useful data. // Layout and thread locality is not guaranteed. Loader::exec(frag.mAccess, data, ldm, waveIndex, waveCount); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template ; - using Loader = typename GetCoopIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -156,6 +163,9 @@ namespace rocwmma // Note: the frag will only be partially filled with useful data. // Layout and thread locality is not guaranteed. Loader::template exec(frag.mAccess, data, ldm, waveIndex); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template ; - using Storer = typename GetCoopIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity checks static_assert(!is_same::value, @@ -204,7 +216,7 @@ namespace rocwmma // Implicit unpack and store // Note: the frag is only be partially filled with useful data. // Layout and thread locality is not guaranteed. - Storer::exec(data, frag.mAccess, ldm, waveIndex, waveCount); + Storer::exec(data, PreStore::exec(frag.mAccess), ldm, waveIndex, waveCount); } template ; - using Storer = typename GetCoopIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity checks static_assert(!is_same::value, @@ -281,7 +294,7 @@ namespace rocwmma // Implicit unpack and store // Note: the frag is only be partially filled with useful data. // Layout and thread locality is not guaranteed. - Storer::template exec(data, frag.mAccess, ldm, waveIndex); + Storer::template exec(data, PreStore::exec(frag.mAccess), ldm, waveIndex); } } // namespace rocwmma diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index 8726f509..cf03d959 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -227,8 +227,10 @@ namespace rocwmma const DataT* data, uint32_t ldm) { - using FragT = decay_t; - using Loader = typename GetIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -241,6 +243,9 @@ namespace rocwmma // Load then implicit pack Loader::exec(frag.mAccess, data, ldm); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template @@ -274,8 +279,10 @@ namespace rocwmma fragment const& frag, uint32_t ldm) { - using FragT = decay_t; - using Storer = typename GetIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity check static_assert(!is_same::value, @@ -287,7 +294,7 @@ namespace rocwmma "Fragment access and store input types do not match"); // Implicit unpack and then store - Storer::exec(data, frag.mAccess, ldm); + Storer::exec(data, PreStore::exec(frag.mAccess), ldm); } template @@ -326,11 +333,21 @@ namespace rocwmma fragment const& b, fragment const& c) { - using FragA = decay_t; - using FragB = decay_t; + using FragA = decay_t; + using FragB = decay_t; + using FragAcc = decay_t; + + using IOConfigA = GetIOConfig_t; + using IOConfigB = GetIOConfig_t; + using IOConfigAcc = GetIOConfig_t; - using IOConfigA = GetIOConfig_t; - using IOConfigB = GetIOConfig_t; + using PreMmaA = typename IOConfigA::PreMmaXForm; + using PreMmaB = typename IOConfigB::PreMmaXForm; + using PreMmaAcc = typename IOConfigAcc::PreMmaXForm; + + using PackA = typename IOConfigA::PackUtil; + using PackB = typename IOConfigB::PackUtil; + using PackAcc = typename IOConfigAcc::PackUtil; // Sanity checks static_assert((IOConfigA::IOShape::BlockDim >= 16) && (IOConfigB::IOShape::BlockDim >= 16) @@ -345,22 +362,19 @@ namespace rocwmma typename IOConfigB::IOLayout::MatrixLayout>, "Input fragment matrix layouts are not orthogonal"); - static_assert(is_layout_same_v, + static_assert(is_layout_same_v, "Input fragment register layouts do not match"); - // static_assert(is_same_v>, - // "Input fragment register layouts are not mfma friendly"); - // Gfx9 uses MFMA, gfx11 uses WMMA - using MMA = conditional_t, Wmma>; // mma functions operate on packed vectors - (*d) = MMA::exec(*a, *b, *c); + (*d) = MMA::exec(PackA::pack(PreMmaA::exec(a.mAccess)), + PackB::pack(PreMmaB::exec(b.mAccess)), + PackAcc::pack(PreMmaAcc::exec(c.mAccess))); } ROCWMMA_DEVICE void synchronize_workgroup() diff --git a/library/include/rocwmma/rocwmma_transforms_impl.hpp b/library/include/rocwmma/rocwmma_transforms_impl.hpp index cd18da5d..f58cad3a 100644 --- a/library/include/rocwmma/rocwmma_transforms_impl.hpp +++ b/library/include/rocwmma/rocwmma_transforms_impl.hpp @@ -95,8 +95,8 @@ namespace rocwmma typename IOConfigB::IOLayout::MatrixLayout>, "Matrix Layouts are not orthogonal"); - static_assert(is_layout_same_v, + static_assert(is_layout_same_v, "Register layouts do not match"); public: @@ -150,9 +150,9 @@ namespace rocwmma typename IOConfigB::IOLayout::MatrixLayout>, "Matrix Layouts are not orthogonal"); - static_assert(is_layout_same_v, - "Register layouts do not match"); + static_assert(is_layout_same_v, + "Fragment register layouts do not match"); public: // Interface @@ -183,25 +183,6 @@ namespace rocwmma template struct ApplyDataLayout; - // Same layout case - template - struct ApplyDataLayout, - DataLayoutT> - { - // Interface - using Type = fragment; - template - ROCWMMA_DEVICE constexpr static inline Type const& exec(Type const& frag) - { - return frag; - } - }; - // Other layout case template , NewDataLayoutT> { - private: - using FragIn = fragment; - using FragOut = fragment; - - using IOConfigIn = GetIOConfig_t; + using Type = fragment; - using RegisterLayoutIn = typename GetIOConfig_t::IOLayout::RegisterLayout; - using RegisterLayoutOut = typename GetIOConfig_t::IOLayout::RegisterLayout; - - // Matrix context, BlockDim and KDim implicitly the same due to re-use of - // MatrixT, BlockM, BlockN, BlockK - - public: - // Interface - using Type = FragOut; - - // Optimal case: input and output register layouts match - template - && is_layout_same_v, - int> - = 0> + template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(FragT const& frag) { - return reinterpret_cast(frag); - } + static_assert( + is_same_v, FragT>, + "Input fragment types do not match"); - // Input and output register layouts do not match: must transform using AOS<->SOA - template - && !is_layout_same_v, - int> - = 0> - ROCWMMA_DEVICE constexpr static inline auto exec(FragT const& frag) - { - // TODO: Make sure to use coop configs to get the right MaxVW!!! - // using IOConfigCoopIn = GetCoopIOConfig_t; - // constexpr uint32_t BlockDim = IOConfigCoop::IOShape::BlockDim; - // constexpr uint32_t MaxVW = IOConfigCoop::IOLayout::MaxVW; - // using RegisterLayoutIncoming = typename IOConfigCoop::IOLayout::RegisterLayout; - - // // Target layouts - // using AosLayout = RegisterLayout::template Aos; - // using SoaLayout = RegisterLayout::template Soa; - - using SrcRegLayout = - typename GetCoopIOConfig_t::IOLayout::RegisterLayout; - using DstRegLayout = - typename GetCoopIOConfig_t::IOLayout::RegisterLayout; - - auto result = FragOut{}; - result.mAccess - = register_layout_transform::exec(frag.mAccess); + using DstFrag = Type; + + // Make sure to use coop configs to get the right MaxVW!!! + using SrcLayout = + typename GetCoopIOConfig_t::IOLayout::FragmentLayout; + using DstLayout = + typename GetCoopIOConfig_t::IOLayout::FragmentLayout; + auto result = DstFrag{}; + result.mAccess + = register_layout_transform::exec(frag.mAccess); return result; } }; From 622a25623a69ce128d55d9595f076dbcc621304f Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 19 Nov 2024 15:50:14 +0000 Subject: [PATCH 14/36] Fix include issues and io_shape test --- .../include/rocwmma/internal/permute_impl.hpp | 3 +- .../include/rocwmma/internal/transforms.hpp | 3 +- test/unit/io_shape_test/detail/io_shape.hpp | 82 +++++++++---------- 3 files changed, 42 insertions(+), 46 deletions(-) diff --git a/library/include/rocwmma/internal/permute_impl.hpp b/library/include/rocwmma/internal/permute_impl.hpp index 2f6a82c5..0e2fcece 100644 --- a/library/include/rocwmma/internal/permute_impl.hpp +++ b/library/include/rocwmma/internal/permute_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #ifndef ROCWMMA_PERMUTE_IMPL_HPP #define ROCWMMA_PERMUTE_IMPL_HPP +#include "mapping_util.hpp" #include "permute.hpp" namespace rocwmma diff --git a/library/include/rocwmma/internal/transforms.hpp b/library/include/rocwmma/internal/transforms.hpp index 82e61dcd..49ef6d9b 100644 --- a/library/include/rocwmma/internal/transforms.hpp +++ b/library/include/rocwmma/internal/transforms.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include "transforms_impl.hpp" #include "vector.hpp" +#include "vector_iterator.hpp" namespace rocwmma { diff --git a/test/unit/io_shape_test/detail/io_shape.hpp b/test/unit/io_shape_test/detail/io_shape.hpp index 7f23eecf..ca8930c1 100644 --- a/test/unit/io_shape_test/detail/io_shape.hpp +++ b/test/unit/io_shape_test/detail/io_shape.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -58,49 +58,43 @@ namespace rocwmma template bool waveTest() { - bool err = false; - constexpr auto BlockDim = std::is_same_v ? BlockM : BlockN; - constexpr auto KDim = std::is_same_v ? BlockM : BlockK; - - constexpr auto MaxVW - = std::is_same_v ? detail:: - MaxVWSelector:: - Result - : std::is_same_v - ? detail:: - MaxVWSelector:: - Result - : (std::is_same::value || ROCWMMA_ARCH_GFX11 ? 1u : 4u); - constexpr auto VW - = std::is_same_v - ? std::is_same::value || BlockDim > 32 ? MaxVW : 1u - : std::is_same_v - ? (std::is_same::value || BlockDim > 32 ? MaxVW : 1u) - : (std::is_same::value ? MaxVW : 1u); - - using RowNT - = LayoutProfile::template RowNT; - using ColNT - = LayoutProfile::template ColNT; - - using Row = LayoutProfile::template Row; - using Col = LayoutProfile::template Col; - - using Profile = typename std::conditional_t< - std::is_same_v, - std::conditional_t, - std::conditional_t, - std::conditional_t, - RowNT>>; - - using DataLayout = DataLayout::template Array1d; - - using IOLayout = IOLayout; - - err |= (IOLayout::MaxVW != MaxVW); - err |= (IOLayout::VW != VW); - err |= (!std::is_same::value); - err |= (!std::is_same::value); + bool err = false; + + // Accum requires WaveCount > 1 + if constexpr(!std::is_same_v || WaveCount == 1) + { + constexpr auto BlockDim = std::is_same_v ? BlockM : BlockN; + constexpr auto KDim = std::is_same_v ? BlockM : BlockK; + + using detail::MaxVWSelector; + using detail::MmaDimSelector; + + constexpr auto ExpectMaxVW + = MaxVWSelector::Result; + + constexpr auto ExpectVW + = std::is_same_v + ? std::is_same::value || BlockDim > 32 + ? ExpectMaxVW + : 1u + : std::is_same_v + ? (std::is_same::value || BlockDim > 32 + ? ExpectMaxVW + : 1u) + : (std::is_same::value ? ExpectMaxVW : 1u); + + constexpr auto ExpectMmaDim = MmaDimSelector::Result; + + using IOLayout = IOLayout; + using IOLayoutInt + = IOLayoutInt; + using ExpectDataLayout = DataLayout::template Array1d; + + err |= (IOLayout::MaxVW != ExpectMaxVW); + err |= (IOLayout::VW != ExpectVW); + err |= (IOLayoutInt::MmaDim != ExpectMmaDim); + err |= (!std::is_same_v); + } return err; } From 2c08b3667d976f74b4b709b371824b71a8c40856 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 19 Nov 2024 15:51:25 +0000 Subject: [PATCH 15/36] Add interleaved layout IOLayoutInt --- .../rocwmma/internal/coop_io_config.hpp | 4 +- .../include/rocwmma/internal/io_config.hpp | 4 +- .../include/rocwmma/internal/io_layout.hpp | 390 +++++++++++++----- 3 files changed, 297 insertions(+), 101 deletions(-) diff --git a/library/include/rocwmma/internal/coop_io_config.hpp b/library/include/rocwmma/internal/coop_io_config.hpp index 9d1a1767..b3bf287e 100644 --- a/library/include/rocwmma/internal/coop_io_config.hpp +++ b/library/include/rocwmma/internal/coop_io_config.hpp @@ -86,14 +86,14 @@ namespace rocwmma typename IOLayout::MatrixLayout, IOLayout::VW>; - using PostLoadXForm = register_layout_transform; using PreMmaXForm = register_layout_transform; using PreStoreXForm = register_layout_transform; + typename IOLayout::StorageLayout>; using Storer = CooperativeStore; - using PostLoadXForm = register_layout_transform; using PreMmaXForm = register_layout_transform; using PreStoreXForm = register_layout_transform; + typename IOLayout::StorageLayout>; using Storer = OpaqueStore struct MaxVWSelector { - private: - enum : uint32_t - { - // For small block sizes (16, 32): - // Best to keep MaxVW high and reduce splits amongst waves. - WaveCountFactor = (BlockDim <= 32) ? 1u : WaveCount, - - // Total number of elements in a single I/O operation - ElementsPerIO = Constants::AMDGCN_WAVE_SIZE * TestWidth * WaveCountFactor, - - // Total number of elements for the entire block - ElementCount = BlockDim * BlockK, - - // Ensure that for MaxVW: - // - A minimum of one IO from each wave can fit - // - A balanced multiple of IOs from each wave - ElementCountTest - = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0), - - // Currently, all layouts are using ColOrthoVW. This means that VW must be less than BlockK - LeadingDimTest = (TestWidth <= BlockK), - - MaxVectorWidth = (bool)ElementCountTest && (bool)LeadingDimTest - ? TestWidth - : MaxVWSelector::Result, - }; + // For small block sizes (16, 32): + // Best to keep MaxVW high and reduce splits amongst waves. + static constexpr uint32_t WaveCountFactor = (BlockDim <= 32) ? 1u : WaveCount; + + // Total number of elements in a single I/O operation + static constexpr uint32_t ElementsPerIO + = Constants::AMDGCN_WAVE_SIZE * TestWidth * WaveCountFactor; + + // Total number of elements for the entire block + static constexpr uint32_t ElementCount = BlockDim * BlockK; + + // Ensure that for MaxVW: + // - A minimum of one IO from each wave can fit + // - A balanced multiple of IOs from each wave + static constexpr bool ElementCountTest + = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0); + + // Currently, all layouts are using ColOrthoVW. This means that VW must be less than BlockK + static constexpr bool LeadingDimTest = (TestWidth <= BlockK); + + // Decide on final MaxVW + static constexpr uint32_t MaxVectorWidth = (ElementCountTest && LeadingDimTest) + ? TestWidth + : MaxVWSelector::Result; public: - enum : uint32_t - { - Result = (uint32_t)MaxVectorWidth - }; + static constexpr uint32_t Result = MaxVectorWidth; }; + // Accumulator case, is architecture specific + template + struct MaxVWSelector + { + static_assert(WaveCount == 1u, "Accumulators are not cooperative"); + + constexpr static uint32_t Result + = (bool)ROCWMMA_ARCH_GFX12 + ? 8u + : ((is_same_v || (bool)ROCWMMA_ARCH_GFX11) ? 1u : 4u); + }; + + // Fallback case for bad test. Stay safe to VW=1 template struct MaxVWSelector { - enum : uint32_t - { - Result = 1u - }; + static constexpr uint32_t Result = 1u; }; } // namespace detail /*! \struct IOLayout - * \brief Definition of VW, MaxVW, data and matrix mapping utilities - * in specific matrix context. - * - * @tparam MatrixT fragment context - * @tparam BlockDim Block leading dimension - * @tparam BlockK Block K-dimension - * @tparam DataT data type - * @tparam DataLayoutT in-memory layout as col_major or row_major - * @tparam WaveCount number of cooperative waves - */ + * \brief Definition of VW, MaxVW, data and matrix mapping utilities + * in specific matrix context. + * + * @tparam MatrixT fragment context + * @tparam BlockDim Block leading dimension + * @tparam BlockK Block K-dimension + * @tparam DataT data type + * @tparam DataLayoutT in-memory layout as col_major or row_major + * @tparam WaveCount number of cooperative waves + */ template { // Vector size properties - enum : uint32_t - { - MaxVW = detail:: - MaxVWSelector::Result, - - VW = is_same::value || BlockDim > 32 ? MaxVW : 1u - }; + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; + constexpr static uint32_t VW + = is_same_v || BlockDim > 32u ? MaxVW : 1u; // DataLayout using DataLayout = DataLayout::template Array1d; // Matrix Layouts - // Layout profile for 'matrix_a': ColNT for small frags, Col for large frags + // Small dim mma friendly using SmallDimMatrixLayout = conditional_t, MatrixLayout::ColOrthoVW, MatrixLayout::ColOrthoVW>; + // Large dim not mma friendly using LargeDimMatrixLayout = conditional_t, MatrixLayout::ColInlineVW, MatrixLayout::ColOrthoVW>; using MatrixLayout - = conditional_t; + = conditional_t; - // Register layouts - using MemoryLayout = RegisterLayout::Storage; - using FragmentLayout = MemoryLayout; - using MmaLayout = RegisterLayout::MmaInput; + + // Register layout required for mma. Expect non-interleaved SOA format. + using MmaLayout = RegisterLayout::MmaInput; + ? RegisterLayout::Format::WMMA_INPUT_GFX11 + : RegisterLayout::Format::SOA>; + // Fragments will keep storage register layout. + // No post-load / pre-store xform + // May require pre-mma xform + using FragmentLayout = StorageLayout; }; template { // Vector size properties - enum : uint32_t - { - MaxVW = detail:: - MaxVWSelector::Result, - - VW = is_same::value || BlockDim > 32 ? MaxVW : 1u - }; + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; + constexpr static uint32_t VW + = is_same_v || BlockDim > 32 ? MaxVW : 1u; // DataLayout using DataLayout = DataLayout::template Array1d; // Matrix Layouts - // Layout profile for 'matrix_a': ColNT for small frags, Col for large frags + // Small dim mma friendly using SmallDimMatrixLayout = conditional_t, MatrixLayout::RowOrthoVW, MatrixLayout::RowOrthoVW>; + // Large dim not mma friendly using LargeDimMatrixLayout = conditional_t, MatrixLayout::RowInlineVW, @@ -200,14 +217,20 @@ namespace rocwmma using MatrixLayout = conditional_t; - // Register layouts - using MemoryLayout = RegisterLayout::Storage; - using FragmentLayout = MemoryLayout; - using MmaLayout = RegisterLayout::MmaInput; + + // Register layout required for mma. Expect non-interleaved SOA format. + using MmaLayout = RegisterLayout::MmaInput; + ? RegisterLayout::Format::WMMA_INPUT_GFX11 + : RegisterLayout::Format::SOA>; + + // Fragments will keep storage register layout. + // No post-load / pre-store xform + // May require pre-mma xform + using FragmentLayout = StorageLayout; }; template { // Vector size properties - enum : uint32_t - { - MaxVW = ROCWMMA_ARCH_GFX12 - ? 8u - : ((is_same::value || ROCWMMA_ARCH_GFX11) ? 1u : 4u), - VW = is_same::value ? MaxVW : 1u - }; + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; + constexpr static uint32_t VW = is_same_v ? MaxVW : 1u; // DataLayout using DataLayout = DataLayout::template Array1d; - // Layout profile for 'accumulator' set to RowNT, small frags + // Always mma friendly using MatrixLayout = conditional_t, MatrixLayout::RowOrthoVW, MatrixLayout::RowOrthoVW>; - // Register layouts - using MemoryLayout = RegisterLayout::Storage; - using MmaLayout = RegisterLayout::MmaAcc; + + // Register layout required for mma. Expect non-interleaved SOA format. + using MmaLayout = RegisterLayout::MmaAcc; - using FragmentLayout = MemoryLayout; + ? RegisterLayout::Format::WMMA_ACC_GFX11 + : RegisterLayout::Format::SOA>; + + // Fragments will keep storage register layout. + // No post-load / pre-store xform + // May require pre-mma xform. + // TODO: Ideally, should really be MmaLayout + // However, MmaAcc frags are restricted to 16/32 MmaDim. + // Once restriction is lifted, should be adjusted. + using FragmentLayout = StorageLayout; }; template struct IOLayout { - using MemoryLayout = void; - using MmaLayout = RegisterLayout::MmaAcc; + ? RegisterLayout::Format::WMMA_ACC_GFX11 + : RegisterLayout::Format::SOA>; + + // Fragments will keep mma register layout. + // No pre-mma xform + using FragmentLayout = MmaLayout; + }; + + namespace detail + { + template + struct MmaDimSelector + { + private: + // Try to get the best interleaved VW along BlockDim axis. + static constexpr uint32_t SizeB128 = 128u >> 2u; + static constexpr uint32_t InterleaveVW = BlockDim / TestMmaDim; + static constexpr uint32_t BytesPerThread = InterleaveVW * sizeof(DataT); + + public: + static constexpr uint32_t Result = (BytesPerThread < SizeB128 ? 16u : TestMmaDim); + }; + + } // namespace detail + + /*! \struct IOLayoutInt + * \brief Definition of VW, MaxVW, data and matrix mapping utilities + * in specific matrix context. + * + * @tparam MatrixT fragment context + * @tparam BlockDim Block leading dimension + * @tparam BlockK Block K-dimension + * @tparam DataT data type + * @tparam DataLayoutT in-memory layout as col_major or row_major + * @tparam WaveCount number of cooperative waves + */ + template + struct IOLayoutInt; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::ColInlineInt, + MatrixLayout::ColOrthoInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect interleaved SOA format. + using MmaLayout = RegisterLayout::MmaInput; + + // Fragments will keep storage register layout. + // No post-load / pre-store xform + // May require pre-mma xform + using FragmentLayout = StorageLayout; + }; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::RowOrthoInt, + MatrixLayout::RowInlineInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect interleaved SOA format. + using MmaLayout = RegisterLayout::MmaInput; + // Fragments will keep storage register layout. + // No post-load / pre-store xform + // May require pre-mma xform + using FragmentLayout = StorageLayout; + }; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::RowOrthoInt, + MatrixLayout::RowInlineInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect interleaved accum format for multiple blocks. + using MmaLayout + = RegisterLayout::MmaAcc; + + // Fragments will keep mma register layout. + // May require post-load / pre-store xform + // No pre-mma xform + using FragmentLayout = MmaLayout; + }; + + template + struct IOLayoutInt + { + // We don't know which storage is needed: no DataLayout + using StorageLayout = void; + + // Register layout required for mma. Expect interleaved accum format for multiple blocks. + using MmaLayout + = RegisterLayout::MmaAcc; + + // Fragments will keep mma register layout. + // No pre-mma xform using FragmentLayout = MmaLayout; }; From d89b5368b2cd155fa8c9317d08caa3bf6d5bb924 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Fri, 22 Nov 2024 15:33:33 +0000 Subject: [PATCH 16/36] Fixes for interleaved layout compatibility --- .../include/rocwmma/internal/io_layout.hpp | 23 ++++++++-- .../layout/register_layout_transforms.hpp | 42 ++++++++++++++++++- .../rocwmma/internal/vector_util_impl.hpp | 2 +- library/include/rocwmma/rocwmma_impl.hpp | 13 +++--- 4 files changed, 66 insertions(+), 14 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index fcb4aecb..363489bd 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -352,7 +352,7 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect interleaved SOA format. - using MmaLayout = RegisterLayout::MmaInput::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; }; template ; // Register layout required for mma. Expect interleaved SOA format. - using MmaLayout = RegisterLayout::MmaInput::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; }; template ::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; }; template struct IOLayoutInt { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + // We don't know which storage is needed: no DataLayout using StorageLayout = void; // Register layout required for mma. Expect interleaved accum format for multiple blocks. using MmaLayout - = RegisterLayout::MmaAcc + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t + && (!traits_lhs::is_register_layout || !traits_rhs::is_register_layout + || !is_layout_orthogonal_v)>> + { + static_assert(0, "Register layout transform is not supported"); + }; + // Apply paths between orthogonal transforms template struct register_layout_transform< @@ -107,9 +118,38 @@ namespace rocwmma = conditional_t; return interleave<1u, storage_traits::KPerThread>(forward(v)); } + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + { + using storage_traits + = conditional_t; + return interleave<1u, storage_traits::KPerThread>(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + { + using storage_traits + = conditional_t; + + return interleave<1u, 4u>(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::AOS_INT) + { + using storage_traits + = conditional_t; + return interleave<1u, 4u>(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::SOA_INT) + { + using storage_traits + = conditional_t; + return interleave<1u, storage_traits::KPerThread>(forward(v)); + } else { - static_assert(0, "Shouldn't get here"); + static_assert(0, "Register layout transform is not implemented"); return v; } } diff --git a/library/include/rocwmma/internal/vector_util_impl.hpp b/library/include/rocwmma/internal/vector_util_impl.hpp index 467b858b..b9e9d1b1 100644 --- a/library/include/rocwmma/internal/vector_util_impl.hpp +++ b/library/include/rocwmma/internal/vector_util_impl.hpp @@ -434,7 +434,7 @@ namespace rocwmma static_assert((GatherSize >= 1u) && (GatherSize <= ElementStride) && (ElementStride % GatherSize == 0) && (VecSize % GatherSize == 0), "Invalid GatherSize"); - static_assert(ElementStride >= 1u && ElementStride < VecSize, "Invalid Stride"); + static_assert(ElementStride >= 1u && ElementStride <= VecSize, "Invalid Stride"); // No transform is needed (NOP) if constexpr(GatherSize == ElementStride || ElementStride == VecSize) diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index cf03d959..552c8871 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -355,15 +355,12 @@ namespace rocwmma && (IOConfigB::IOShape::BlockDim <= 32), "Input fragment BlockDim is not mfma friendly"); - static_assert(IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim, - "KDim of input fragments must match"); + static_assert((IOConfigA::IOShape::BlockDim == IOConfigB::IOShape::BlockDim) + && (IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim), + "BlockDim and KDim of input fragments must match"); - static_assert(is_layout_orthogonal_v, - "Input fragment matrix layouts are not orthogonal"); - - static_assert(is_layout_same_v, + static_assert(is_layout_same_v, "Input fragment register layouts do not match"); // Gfx9 uses MFMA, gfx11 uses WMMA From f6ff3e4945a4d502ba4b4a0daa2b78d1bcf566a4 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 25 Nov 2024 18:19:24 +0000 Subject: [PATCH 17/36] Add initial non-interleaved layout traits test --- .../include/rocwmma/internal/io_layout.hpp | 1 - .../internal/layout/layout_profile.hpp | 290 ---- test/unit/CMakeLists.txt | 3 +- test/unit/layout_traits_test/CMakeLists.txt | 34 + .../detail/layout_traits.hpp | 155 ++ .../device/layout_traits.hpp | 1405 +++++++++++++++++ .../test/layout_traits_16.cpp | 92 ++ 7 files changed, 1688 insertions(+), 292 deletions(-) delete mode 100644 library/include/rocwmma/internal/layout/layout_profile.hpp create mode 100644 test/unit/layout_traits_test/CMakeLists.txt create mode 100644 test/unit/layout_traits_test/detail/layout_traits.hpp create mode 100644 test/unit/layout_traits_test/device/layout_traits.hpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_16.cpp diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 363489bd..e79d0e14 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -29,7 +29,6 @@ #include "api_fwd.hpp" #include "constants.hpp" #include "layout/layout.hpp" -#include "layout/layout_profile.hpp" #include "types.hpp" namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/layout_profile.hpp b/library/include/rocwmma/internal/layout/layout_profile.hpp deleted file mode 100644 index 8ab2ec69..00000000 --- a/library/include/rocwmma/internal/layout/layout_profile.hpp +++ /dev/null @@ -1,290 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_LAYOUT_PROFILE_HPP -#define ROCWMMA_LAYOUT_PROFILE_HPP - -#include "layout.hpp" - -namespace rocwmma -{ - // Layout profiles are high-level objects that describe fragments in three mapped spaces: - // 1. DataLayout: data locality in 1D memory space (row_major or col_major) - // 2. MatrixLayout: data locality in 2D matrix space (ColOrthoVW, ColInlineVW, etc.) - // 3. RegisterLayout: data locality in register space (Storage, or MmaInput) - namespace LayoutProfile - { - // ColNT is a layout profile that has the following properties: - // 1. Leading dimension is aligned with column elements of fragment data: - // - BlockDim is assumed the column size, or BlockM dimension. - // - Analogous to capturing columns of 'matrix A'. - // 2. When BlockDim is supported by mma, register elements are always in MmaInput friendly register layout. - // 3. Register layout does NOT change whether DataLayout is col_major or row_major (free DataLayoutT change). - // 4. MatrixLayout will capture contiguous column elements across contiguous threads. - // 5. VectorWidth is fixed to 1 in col_major to ensure #4 (non-optimal). - template - struct ColNT - { - // Layouts - using DataLayout = DataLayout::template Array1d; - - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::ColOrthoVW, - MatrixLayout::ColOrthoVW>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // ColNT enforces consistent in-register alignment of contiguous matrix column - // elements in both row_major or col_major data layouts. - // This layout cannot support for VW > 1 in col_major data layout otherwise the - // ordering is broken. - static_assert(!(is_same_v && VectorWidth > 1), - "ColNT in col_major does not support VectorWidth > 1"); - - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // RowNT is a layout profile that has the following properties: - // 1. Leading dimension is aligned with row elements of fragment data: - // - BlockDim is assumed the row size, or BlockN dimension. - // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. When BlockDim is supported by mma, register elements are always MmaInput friendly register layout. - // 3. Register layout does NOT change whether DataLayout is col_major or row_major (fast DataLayoutT change). - // 4. MatrixLayout will capture contiguous row elements across contiguous threads. - // 5. VectorWidth is fixed to 1 in row_major to ensure #4 (non-optimal). - template - struct RowNT - { - // Layouts - using DataLayout = DataLayout::template Array1d; - - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::RowOrthoVW, - MatrixLayout::RowOrthoVW>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // RowNT enforces consistent in-register alignment of contiguous matrix row - // elements in both in row_major or col_major data layouts. - // This layout cannot support for VW > 1 in row_major data layout. - static_assert(!(is_same_v && VectorWidth > 1), - "RowNT in row_major does not support VectorWidth > 1"); - - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // Col is a layout profile that has the following properties: - // 1. Leading dimension is aligned with column elements of fragment data: - // - BlockDim is assumed the column size, or BlockM dimension. - // - Analogous to capturing columns of 'matrix A'. - // 2. Register layout is dynamic: - // - col_major data is stored in AOS register layout (non-MmaInput friendly), and - // - row_major data is stored in SOA register layout (MmaInput friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. Must convert to SOA if using with MFMA. - template - struct Col - { - // Layouts - using DataLayout = DataLayout::template Array1d; - - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::ColInlineVW, - MatrixLayout::ColOrthoVW>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // Row is a layout profile that has the following properties: - // 1. Leading dimension is aligned with row elements of fragment data: - // - BlockDim is assumed the row size, or BlockN dimension. - // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. Register layout is dynamic: - // - row_major data is stored in AOS register layout (non-MFMA friendly), and - // - col_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. - template - struct Row - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::RowInlineVW, - MatrixLayout::RowOrthoVW>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - //////////////// Interleaved layouts ///////////// - - // Col is a layout profile that has the following properties: - // 1. Leading dimension is aligned with column elements of fragment data: - // - BlockDim is assumed the column size, or BlockM dimension. - // - Analogous to capturing columns of 'matrix A'. - // 2. Register layout is dynamic: - // - col_major data is stored in AOS register layout (non-MFMA friendly), and - // - row_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. - template - struct ColInt - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::ColInlineInt, - MatrixLayout::ColOrthoInt>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - // TODO: fix - // static_assert( - // !(is_same_v && (MaxVectorWidth > BlockK)), - // "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // Row is a layout profile that has the following properties: - // 1. Leading dimension is aligned with row elements of fragment data: - // - BlockDim is assumed the row size, or BlockN dimension. - // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. Register layout is dynamic: - // - row_major data is stored in AOS register layout (non-MFMA friendly), and - // - col_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. - template - struct RowInt - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::RowInlineInt, - MatrixLayout::RowOrthoInt>; - - using RegisterLayout = RegisterLayout::Storage; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - // TODO: fix - // static_assert( - // !(is_same_v && (MaxVectorWidth > BlockK)), - // "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - } // namespace LayoutProfile - -} // namespace rocwmma - -#endif // ROCWMMA_LAYOUT_PROFILE_HPP diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 37709de6..a6183bca 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -2,7 +2,7 @@ # # MIT License # - # Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + # Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -51,6 +51,7 @@ endfunction() # Add unit tests add_subdirectory(contamination_test) add_subdirectory(layout_test) +add_subdirectory(layout_traits_test) add_subdirectory(map_util_test) add_subdirectory(load_store_matrix_sync_test) add_subdirectory(load_store_matrix_coop_sync_test) diff --git a/test/unit/layout_traits_test/CMakeLists.txt b/test/unit/layout_traits_test/CMakeLists.txt new file mode 100644 index 00000000..0850c465 --- /dev/null +++ b/test/unit/layout_traits_test/CMakeLists.txt @@ -0,0 +1,34 @@ +############################################################################### +# +# MIT License +# +# Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### + +# Include path for current test files +set(ROCWMMA_TEST_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR} ${ROCWMMA_TEST_INCLUDE_DIRS}) + +set(LayoutTraitsTestSources ${UnitCommonSources} + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_16.cpp + ) + +add_rocwmma_unit_test(layout_traits_test ${LayoutTraitsTestSources}) diff --git a/test/unit/layout_traits_test/detail/layout_traits.hpp b/test/unit/layout_traits_test/detail/layout_traits.hpp new file mode 100644 index 00000000..d7e1c4c8 --- /dev/null +++ b/test/unit/layout_traits_test/detail/layout_traits.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP +#define ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP + +#include "device/layout_traits.hpp" +#include "helper_macros.hpp" +#include "unit_kernel_base.hpp" + +namespace rocwmma +{ + + // Wrapper into the actual device function + template + struct LayoutTraitsKernel final : public UnitKernelBase + { + private: + using Base = UnitKernelBase; + + template + using TestGuard = FragSize_guard; + + public: + LayoutTraitsKernel() = default; + ~LayoutTraitsKernel() final = default; + + void setupImpl(typename Base::DataStorage::ProblemSize const& probsize) final + { + // Need at least 1 element for the result + auto& dataInstance = Base::DataStorage::instance(); + dataInstance->resizeStorage(probsize); + + dataInstance->hostOut().get()[0] = static_cast(ERROR_VALUE); + dataInstance->copyData(dataInstance->deviceOut(), dataInstance->hostOut(), 1); + + // Pass in warpSize from host to validate + Base::mParam1 = static_cast(Base::DeviceInfo::instance()->warpSize()); + } + + void validateResultsImpl() final + { + auto& dataInstance = Base::DataStorage::instance(); + + // Cache current kernel result from device + dataInstance->copyData(dataInstance->hostOut(), dataInstance->deviceOut(), 1); + + // Check the single output result + Base::mValidationResult = (dataInstance->hostOut().get()[0] == DataT(SUCCESS_VALUE)); + } + + bool checkQuirks() const final + { + auto waveSize = Base::DeviceInfo::instance()->warpSize(); + auto deviceArch = Base::DeviceInfo::instance()->getGcnArch(); + + // The test guard for this class requires 2 values at runtime. + auto dispatchGuard = [waveSize, deviceArch]() { + bool dispatchResult = false; + +#define CASE_IMPL_ASSIGN2(WAVE_SIZE, ARCH_ID) \ + dispatchResult = TestGuard::enable(); + +#define SWITCH_BODY_WAVE_SIZE(ARCH_ID) \ + ROCWMMA_SWITCH_BODY2_ARG2( \ + waveSize, CASE_IMPL_ASSIGN2, HipDevice::Wave32, HipDevice::Wave64, ARCH_ID) + +#define DISPATCH_GUARD_BODY \ + ROCWMMA_SWITCH_BODY10_ARG1(deviceArch, \ + SWITCH_BODY_WAVE_SIZE, \ + HipDevice::GFX908, \ + HipDevice::GFX90A, \ + HipDevice::GFX940, \ + HipDevice::GFX941, \ + HipDevice::GFX942, \ + HipDevice::GFX1100, \ + HipDevice::GFX1101, \ + HipDevice::GFX1102, \ + HipDevice::GFX1200, \ + HipDevice::GFX1201) + + DISPATCH_GUARD_BODY + +#undef CASE_IMPL_ASSIGN2 +#undef SWITCH_BODY_WAVE_SIZE +#undef DISPATCH_GUARD_BODY + + return dispatchResult; + }; + + return Base::checkQuirks() && dispatchGuard(); + } + + typename Base::KernelFunc kernelImpl() const final + { + return typename Base::KernelFunc(layoutTraitsTest); + } + }; + + // This is the GeneratorImpl class + struct LayoutTraitsGenerator + { + // Indices to test parameters + enum : uint32_t + { + BlockM = 0, + BlockN = 1, + DataT = 2, + DataLayoutT = 3 + }; + + using ResultT = std::shared_ptr; + + template + static ResultT generate(std::tuple testParams) + { + // Map GTest params to Kernel params + using TestParamsT = std::tuple; + using KernelT + = LayoutTraitsKernel::value, // BlockM + std::tuple_element_t::value, // BlockN + std::tuple_element_t, // DataT + std::tuple_element_t // DataLayout + >; + + return std::make_shared(); + } + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp new file mode 100644 index 00000000..ae9a15f6 --- /dev/null +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -0,0 +1,1405 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP +#define ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP + +#include + +#include "unit_test_traits.hpp" + +static constexpr uint32_t ERROR_VALUE = 7; +static constexpr uint32_t SUCCESS_VALUE = 0; + +namespace rocwmma +{ + + // template + // ROCWMMA_HOST bool testLayoutPair(const char* file, const char* line) + // { + // constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + // constexpr bool is_layout_orthogonal_result = rocwmma::is_layout_orthogonal_v; + // constexpr bool compare_result = ((is_layout_same_result == ExpectSame) && (is_layout_orthogonal_result == ExpectOrthogonal)); + + // if (DebugOnFail) + // { + // stream << "File: " << file << " L:" << line << std::endl; + // stream << "" << std::endl; + // stream << "Lhs: " << LayoutLhs{} << std::endl; + // stream << rocwmma::layout_traits{}; + // stream << "Rhs: " << LayoutRhs{} << std::endl; + // stream << rocwmma::layout_traits{}; + // stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame << std::endl; + // stream << "is_layout_orthogonal: " << is_layout_orthogonal_result << " Expected: " << ExpectOrthogonal << std::endl; + // stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; + // stream << "" << std::endl; + // } + + // return compare_result; + // } + + ROCWMMA_DEVICE inline bool isFirstThread() + { + return (threadIdx.x == 0) && (threadIdx.y == 0) && (threadIdx.z == 0) && (blockIdx.x == 0) + && (blockIdx.y == 0) && (blockIdx.z == 0); + } + + template + ROCWMMA_DEVICE bool testLayoutPair(const char* file, uint32_t line) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if(!compare_result && DebugOnFail && isFirstThread()) + { + printf("File: %s L:%d\n", file, line); + printf("\n"); + printf("is_layout_same: %d (Expected: %d)\n", is_layout_same_result, ExpectSame); + printf("is_layout_orthogonal: %d (Expected: %d)\n", + is_layout_orthogonal_result, + ExpectOrthogonal); + printf("%s\n", (compare_result ? "PASS" : "FAIL")); + printf("\n"); + } + + return compare_result; + } + +#define ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( \ + LayoutLhs, LayoutRhs, ExpectSame, ExpectOrthogonal, DebugOnFail) \ + testLayoutPair(__FILE__, \ + __LINE__); + + template + ROCWMMA_DEVICE bool dataLayoutTraitsTest() + { + constexpr bool debug_on_fail = true; + + using rocwmma::DataLayout::ColMajor; + using rocwmma::DataLayout::RowMajor; + + // DataLayouts are invariant of matrix layout properties + // Test both the meta tags and functional classes + using SameMeta = conditional_t, row_major, col_major>; + using OrthoMeta = conditional_t, col_major, row_major>; + using SameFunc = conditional_t, RowMajor, ColMajor>; + using OrthoFunc = conditional_t, ColMajor, RowMajor>; + + bool result = true; + + result + &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameMeta, true, false, debug_on_fail); + result + &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoMeta, false, true, debug_on_fail); + result + &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameFunc, true, false, debug_on_fail); + result + &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoFunc, false, true, debug_on_fail); + + return result; + } + + template + struct RegisterLayoutTestingSet + { + using ColInline = RegisterLayout::Storage< + MatrixLayout::ColInlineVW, + DataLayout>; + using ColOrtho = RegisterLayout::Storage< + MatrixLayout::ColOrthoVW, + DataLayout>; + using RowInline = RegisterLayout::Storage< + MatrixLayout::RowInlineVW, + DataLayout>; + using RowOrtho = RegisterLayout::Storage< + MatrixLayout::RowOrthoVW, + DataLayout>; + + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; + }; + + template + using MatrixLayout_t = typename layout_traits::MatrixLayout; + + template + ROCWMMA_DEVICE bool matrixLayoutTraitsTestNonInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Testing MatrixLayout properties + // MatrixLayouts are invariant to vector width + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set = RegisterLayoutTestingSet; + + bool result = true; + + // Matrix <-> Matrix layout + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + true, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + true, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + true, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + true, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + true, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + true, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + true, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + false, + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, + MatrixLayout_t, + true, + false, + debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool testNonInterleavedMma() + { + constexpr bool validMmaDim = LayoutTraits_impl::testSupportedMmaDim(); + constexpr bool validLayout = LayoutTraits_impl::testStorageLayoutIdentity(); + return validMmaDim && validLayout; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Storage <-> storage layout + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved1() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved2() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1u + // datalayout = same + constexpr uint32_t VectorWidth = 1u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved3() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1u + // datalayout = orthogonal + constexpr uint32_t VectorWidth = 1u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved4() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW0 = 1u + // VW1 = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth0 = 1u; + constexpr uint32_t VectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Already tested + if constexpr(VectorWidth0 == VectorWidth1) + { + return result; + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::ColOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::ColInline, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColInline, + (is_layout_same_v), + false, + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::RowOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowInline, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::RowInline, + (is_layout_same_v), + false, + debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved5() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW0 = 1u + // VW1 = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth0 = 1u; + constexpr uint32_t VectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + bool result = true; + + // Already tested + if constexpr(VectorWidth0 == VectorWidth1) + { + return result; + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::ColOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::ColInline, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColInline, + (is_layout_same_v), + false, + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::RowOrtho, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowInline, + (is_layout_same_v), + false, + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::RowInline, + (is_layout_same_v), + false, + debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved6() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1 + // MaxVW0 = 1 + // MaxVW1 = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth = 1u; + constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved7() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1 + // MaxVW0 = 1 + // MaxVW1 = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth = 1u; + constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved8() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Different BlockDim / BlockK + constexpr uint32_t VectorWidth = MaxVectorWidth; + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 16u ? 32u : 16u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 16u ? 32u : 16u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved9() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Different size DataT + constexpr uint32_t VectorWidth = MaxVectorWidth; + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int16_t, + conditional_t>>>; + + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved10() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Same size DataT + constexpr uint32_t VectorWidth = MaxVectorWidth; + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int8_t, + conditional_t>>>; + + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Already tested same type + if constexpr(is_same_v) + { + return result; + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, + typename Set1::ColOrtho, + false, + (is_layout_same_v), + debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, + typename Set1::ColInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, + typename Set1::RowOrtho, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, + typename Set1::RowInline, + false, + (is_layout_same_v), + debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( + typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + return result; + } + + template + ROCWMMA_DEVICE bool testBarrageNonInterleaved() + { + bool result = true; + result &= matrixLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved1(); + result &= registerLayoutTraitsTestNonInterleaved2(); + result &= registerLayoutTraitsTestNonInterleaved3(); + result &= registerLayoutTraitsTestNonInterleaved4(); + result &= registerLayoutTraitsTestNonInterleaved5(); + result &= registerLayoutTraitsTestNonInterleaved6(); + result &= registerLayoutTraitsTestNonInterleaved7(); + result &= registerLayoutTraitsTestNonInterleaved8(); + result &= registerLayoutTraitsTestNonInterleaved9(); + result &= registerLayoutTraitsTestNonInterleaved10(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestA() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockM; + constexpr uint32_t BlockK = BlockN; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestB() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestAcc() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestAccVoid() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + using DataLayoutT = void; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + //result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + __global__ void layoutTraitsTest(uint32_t m, + uint32_t n, + DataT const* in, + DataT* out, + uint32_t ld, + DataT param1, + DataT param2) + { + __shared__ int32_t result; + result = 0; + synchronize_workgroup(); + + bool success = true; + + success &= layoutTraitsTestA(); + success &= layoutTraitsTestB(); + success &= layoutTraitsTestAcc(); + + // Reduce error count + atomicAdd(&result, (int32_t)success); + + // Wait for all threads + synchronize_workgroup(); + + // Just need one thread to update output + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 + && blockIdx.y == 0 && blockIdx.z == 0) + { + out[0] = static_cast(result == 0 ? ERROR_VALUE : SUCCESS_VALUE); + } + } + +} // namespace rocwmma + +#endif // ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP diff --git a/test/unit/layout_traits_test/test/layout_traits_16.cpp b/test/unit/layout_traits_test/test/layout_traits_16.cpp new file mode 100644 index 00000000..97181f20 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_16.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = std::tuple; //typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes16; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsTest16 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsTest16, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsTest16, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); From 9cc7ebcc43b04c246b16a01f60961d80bc76bd89 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 25 Nov 2024 19:50:22 +0000 Subject: [PATCH 18/36] Add DataT to Mma layout interface. Add checks for data size comparison --- .../include/rocwmma/internal/io_layout.hpp | 8 + .../rocwmma/internal/layout/layout.hpp | 36 +- .../layout/register_layout_traits_impl.hpp | 56 +- .../device/layout_traits.hpp | 943 +++++------------- 4 files changed, 344 insertions(+), 699 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index e79d0e14..fc713c7a 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -174,6 +174,7 @@ namespace rocwmma // Register layout required for mma. Expect non-interleaved SOA format. using MmaLayout = RegisterLayout::MmaInput + typename DataT, + bool Interleaved, + Format Fmt = Interleaved ? Format::SOA_INT : Format::SOA> struct MmaInput { }; // A mnemonic used to describe the register layout is suitable for mma input for accumulator input/output template + typename DataT, + bool Interleaved, + Format Fmt = Interleaved ? Format::ACC_INT_A_MAJOR : Format::SOA> struct MmaAcc { }; @@ -239,20 +241,28 @@ namespace std return stream << "Storage<" << MatrixLayout{} << ", " << DataLayout{} << ">"; } - template + template inline ostream& operator<<( - ostream& stream, - rocwmma::RegisterLayout::MmaInput const& register_layout) + ostream& stream, + rocwmma::RegisterLayout::MmaInput const& register_layout) { - return stream << "MmaInput<" << MmaDim << ", " << Interleaved << ", " << Fmt << ">"; + return stream << "MmaInput<" << MmaDim << ", " << rocwmma::dataTypeToString() << ", " + << Interleaved << ", " << Fmt << ">"; } - template - inline ostream& - operator<<(ostream& stream, - rocwmma::RegisterLayout::MmaAcc const& register_layout) + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::MmaAcc const& register_layout) { - return stream << "MmaAcc<" << MmaDim << ", " << Interleaved << ", " << Fmt << ">"; + return stream << "MmaAcc<" << MmaDim << ", " << rocwmma::dataTypeToString() << ", " + << Interleaved << ", " << Fmt << ">"; } } // namespace std diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index def554c8..d4ad7962 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -55,13 +55,13 @@ namespace rocwmma { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; @@ -80,8 +80,8 @@ namespace rocwmma { }; - template - struct is_mma_input> : public true_type + template + struct is_mma_input> : public true_type { }; @@ -90,8 +90,8 @@ namespace rocwmma { }; - template - struct is_mma_acc> : public true_type + template + struct is_mma_acc> : public true_type { }; @@ -290,13 +290,19 @@ namespace rocwmma && testSupportedFormat>(); }; - template - struct register_layout_derived_traits> + template + struct register_layout_derived_traits< + MmaInput> : public matrix_layout_traits, public data_layout_traits { using MatrixLayout = void; using DataLayout = void; + using DataT = LayoutDataT; + // Overrides constexpr static bool is_interleaved = LayoutIsInterleaved; constexpr static uint32_t MmaDim = LayoutMmaDim; @@ -305,17 +311,25 @@ namespace rocwmma constexpr static RegisterLayout::Format Format = Fmt; constexpr static bool is_valid - = testSupportedMmaDim>() - && testSupportedFormat>(); + = testSupportedMmaDim< + MmaInput>() + && testSupportedFormat< + MmaInput>(); }; - template - struct register_layout_derived_traits> + template + struct register_layout_derived_traits< + MmaAcc> : public matrix_layout_traits, public data_layout_traits { using MatrixLayout = void; using DataLayout = void; + using DataT = LayoutDataT; + // Overrides constexpr static bool is_interleaved = LayoutIsInterleaved; constexpr static uint32_t MmaDim = LayoutMmaDim; @@ -324,8 +338,9 @@ namespace rocwmma constexpr static RegisterLayout::Format Format = Fmt; constexpr static bool is_valid - = testSupportedMmaDim>() - && testSupportedFormat>(); + = testSupportedMmaDim>() + && testSupportedFormat< + MmaAcc>(); }; // Combine base instance traits with specific layout classifiers @@ -369,9 +384,12 @@ namespace rocwmma // Matching MmaDim, interleaving and validity // Note: matching validity does not imply valid! // Cannot mix valid with invalid layouts - constexpr bool BaseTest = (traits_lhs::MmaDim == traits_rhs::MmaDim) - && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) - && (traits_lhs::is_valid == traits_rhs::is_valid); + // Datatype must have same size for same register layout + constexpr bool BaseTest + = (traits_lhs::MmaDim == traits_rhs::MmaDim) + && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) + && (traits_lhs::is_valid == traits_rhs::is_valid) + && (sizeof(typename traits_lhs::DataT) == sizeof(typename traits_rhs::DataT)); // MmaInput <-> MmaInput // MmaAcc <-> MmaAcc diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp index ae9a15f6..9b0813c7 100644 --- a/test/unit/layout_traits_test/device/layout_traits.hpp +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -117,14 +117,12 @@ namespace rocwmma bool result = true; - result - &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameMeta, true, false, debug_on_fail); - result - &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoMeta, false, true, debug_on_fail); - result - &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameFunc, true, false, debug_on_fail); - result - &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoFunc, false, true, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameMeta, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoMeta, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameFunc, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoFunc, false, true, debug_on_fail); + // clang-format on return result; } @@ -150,8 +148,8 @@ namespace rocwmma MatrixLayout::RowOrthoVW, DataLayout>; - using MmaInput = RegisterLayout::MmaInput; - using MmaAcc = RegisterLayout::MmaAcc; + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; }; template @@ -179,86 +177,27 @@ namespace rocwmma bool result = true; // Matrix <-> Matrix layout - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - true, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - true, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - true, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - true, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - true, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - true, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - true, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - false, - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, - MatrixLayout_t, - true, - false, - debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + // clang-format on return result; } @@ -300,53 +239,26 @@ namespace rocwmma bool result = true; // Storage <-> storage layout - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts @@ -370,6 +282,8 @@ namespace rocwmma // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + // clang-format on + return result; } @@ -401,53 +315,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format on return result; } @@ -480,41 +368,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format on return result; } @@ -547,41 +421,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format on return result; } @@ -622,89 +482,27 @@ namespace rocwmma return result; } - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::ColOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::ColInline, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColInline, - (is_layout_same_v), - false, - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::RowOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowInline, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::RowInline, - (is_layout_same_v), - false, - debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + // clang-format on return result; } @@ -745,89 +543,27 @@ namespace rocwmma return result; } - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::ColOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::ColInline, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColInline, - (is_layout_same_v), - false, - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::RowOrtho, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowInline, - (is_layout_same_v), - false, - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::RowInline, - (is_layout_same_v), - false, - debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + // clang-format on return result; } @@ -864,41 +600,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format on return result; } @@ -935,41 +657,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format on return result; } @@ -1007,41 +715,27 @@ namespace rocwmma bool result = true; - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format on return result; } @@ -1091,41 +785,27 @@ namespace rocwmma return result; } - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + // clang-format on return result; } @@ -1175,53 +855,27 @@ namespace rocwmma return result; } - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, - typename Set1::ColOrtho, - false, - (is_layout_same_v), - debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, - typename Set1::ColInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, - typename Set1::RowOrtho, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, - typename Set1::RowInline, - false, - (is_layout_same_v), - debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( - typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + // clang-format on return result; } @@ -1234,66 +888,21 @@ namespace rocwmma ROCWMMA_DEVICE bool testBarrageNonInterleaved() { bool result = true; - result &= matrixLayoutTraitsTestNonInterleaved0(); - result &= registerLayoutTraitsTestNonInterleaved0(); - result &= registerLayoutTraitsTestNonInterleaved1(); - result &= registerLayoutTraitsTestNonInterleaved2(); - result &= registerLayoutTraitsTestNonInterleaved3(); - result &= registerLayoutTraitsTestNonInterleaved4(); - result &= registerLayoutTraitsTestNonInterleaved5(); - result &= registerLayoutTraitsTestNonInterleaved6(); - result &= registerLayoutTraitsTestNonInterleaved7(); - result &= registerLayoutTraitsTestNonInterleaved8(); - result &= registerLayoutTraitsTestNonInterleaved9(); - result &= registerLayoutTraitsTestNonInterleaved10(); + + // clang-format off + result &= matrixLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved1(); + result &= registerLayoutTraitsTestNonInterleaved2(); + result &= registerLayoutTraitsTestNonInterleaved3(); + result &= registerLayoutTraitsTestNonInterleaved4(); + result &= registerLayoutTraitsTestNonInterleaved5(); + result &= registerLayoutTraitsTestNonInterleaved6(); + result &= registerLayoutTraitsTestNonInterleaved7(); + result &= registerLayoutTraitsTestNonInterleaved8(); + result &= registerLayoutTraitsTestNonInterleaved9(); + result &= registerLayoutTraitsTestNonInterleaved10(); + // clang-format on return result; } From 368c6dc41882591e8751df8bc17db17423b659c7 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 27 Nov 2024 00:14:03 +0000 Subject: [PATCH 19/36] Fixes f64 tests. Adds all block sizes tests. --- .../layout/register_layout_traits_impl.hpp | 14 +- test/unit/layout_traits_test/CMakeLists.txt | 4 + .../device/layout_traits.hpp | 568 ++++++++++++++---- .../test/layout_traits_128.cpp | 92 +++ .../test/layout_traits_16.cpp | 2 +- .../test/layout_traits_256.cpp | 92 +++ .../test/layout_traits_32.cpp | 92 +++ .../test/layout_traits_64.cpp | 92 +++ 8 files changed, 823 insertions(+), 133 deletions(-) create mode 100644 test/unit/layout_traits_test/test/layout_traits_128.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_256.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_32.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_64.cpp diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index d4ad7962..0c60a82f 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -175,8 +175,9 @@ namespace rocwmma ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaDim() { using traits = register_layout_traits; - return ((bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED && traits::MmaDim == 16u) - || ((bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED && traits::MmaDim == 32u); + return (traits::MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (traits::MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); // f64 mfma only 16 } // Based on the current architecture, which register layout formats currently supported @@ -228,8 +229,7 @@ namespace rocwmma return traits::is_storage && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) || (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT) - || (traits::Format == Format::Invalid)); + || (traits::Format == Format::AOS_INT)); } } @@ -384,12 +384,12 @@ namespace rocwmma // Matching MmaDim, interleaving and validity // Note: matching validity does not imply valid! // Cannot mix valid with invalid layouts - // Datatype must have same size for same register layout + // Datatype must be same constexpr bool BaseTest = (traits_lhs::MmaDim == traits_rhs::MmaDim) && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) && (traits_lhs::is_valid == traits_rhs::is_valid) - && (sizeof(typename traits_lhs::DataT) == sizeof(typename traits_rhs::DataT)); + && (is_same_v); // MmaInput <-> MmaInput // MmaAcc <-> MmaAcc @@ -466,7 +466,7 @@ namespace rocwmma // Special case: interleaved layouts // Check matching thread dims and if either one is == 1u. - // Register contents will be identical, regardless if the format matches. + // Register contents will be identical, regardless of different formats. constexpr bool TestIdentityQuirks = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); diff --git a/test/unit/layout_traits_test/CMakeLists.txt b/test/unit/layout_traits_test/CMakeLists.txt index 0850c465..7c33c88d 100644 --- a/test/unit/layout_traits_test/CMakeLists.txt +++ b/test/unit/layout_traits_test/CMakeLists.txt @@ -29,6 +29,10 @@ set(ROCWMMA_TEST_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR} ${ROCWMMA_TEST_INCLUDE set(LayoutTraitsTestSources ${UnitCommonSources} ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_128.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_256.cpp ) add_rocwmma_unit_test(layout_traits_test ${LayoutTraitsTestSources}) diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp index 9b0813c7..c76ec791 100644 --- a/test/unit/layout_traits_test/device/layout_traits.hpp +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -37,29 +37,37 @@ static constexpr uint32_t SUCCESS_VALUE = 0; namespace rocwmma { - // template - // ROCWMMA_HOST bool testLayoutPair(const char* file, const char* line) - // { - // constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; - // constexpr bool is_layout_orthogonal_result = rocwmma::is_layout_orthogonal_v; - // constexpr bool compare_result = ((is_layout_same_result == ExpectSame) && (is_layout_orthogonal_result == ExpectOrthogonal)); - - // if (DebugOnFail) - // { - // stream << "File: " << file << " L:" << line << std::endl; - // stream << "" << std::endl; - // stream << "Lhs: " << LayoutLhs{} << std::endl; - // stream << rocwmma::layout_traits{}; - // stream << "Rhs: " << LayoutRhs{} << std::endl; - // stream << rocwmma::layout_traits{}; - // stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame << std::endl; - // stream << "is_layout_orthogonal: " << is_layout_orthogonal_result << " Expected: " << ExpectOrthogonal << std::endl; - // stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; - // stream << "" << std::endl; - // } - - // return compare_result; - // } + template + ROCWMMA_HOST bool testLayoutPair(const char* file, const char* line, std::ostream& stream) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if(DebugOnFail) + { + stream << "File: " << file << " L:" << line << std::endl; + stream << "" << std::endl; + stream << "Lhs: " << LayoutLhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "Rhs: " << LayoutRhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame + << std::endl; + stream << "is_layout_orthogonal: " << is_layout_orthogonal_result + << " Expected: " << ExpectOrthogonal << std::endl; + stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; + stream << "" << std::endl; + } + + return compare_result; + } ROCWMMA_DEVICE inline bool isFirstThread() { @@ -202,12 +210,33 @@ namespace rocwmma return result; } - template - ROCWMMA_DEVICE bool testNonInterleavedMma() + template + ROCWMMA_DEVICE constexpr bool testRowMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testColMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testMmaDim() { - constexpr bool validMmaDim = LayoutTraits_impl::testSupportedMmaDim(); - constexpr bool validLayout = LayoutTraits_impl::testStorageLayoutIdentity(); - return validMmaDim && validLayout; + return (MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); + } + + template + ROCWMMA_DEVICE constexpr bool testMmaAccVW() + { + return MaxVectorWidth + == ((bool)ROCWMMA_ARCH_GFX12 + ? 8u + : ((is_same_v || (bool)ROCWMMA_ARCH_GFX11) ? 1u : 4u)); } template ; + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + bool result = true; + // Case is tested in #2 + if constexpr(VectorWidth == 1u) + { + return result; + } + // Storage <-> storage layout // clang-format off result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); - + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -313,15 +359,31 @@ namespace rocwmma MaxVectorWidth, rocwmma::orthogonal_layout_t>; + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + bool result = true; + // Case is tested in #3 + if constexpr(VectorWidth == 1u) + { + return result; + } + // clang-format off result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); @@ -329,12 +391,36 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -366,6 +452,9 @@ namespace rocwmma MaxVectorWidth, DataLayoutT>; + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + bool result = true; // clang-format off @@ -388,6 +477,30 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -419,6 +532,9 @@ namespace rocwmma MaxVectorWidth, rocwmma::orthogonal_layout_t>; + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + bool result = true; // clang-format off @@ -441,6 +557,30 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -474,34 +614,68 @@ namespace rocwmma MaxVectorWidth, DataLayoutT>; + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + bool result = true; - // Already tested + // Case tested in #0,1,2,3 if constexpr(VectorWidth0 == VectorWidth1) { return result; } // clang-format off - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -535,34 +709,68 @@ namespace rocwmma MaxVectorWidth, rocwmma::orthogonal_layout_t>; + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + bool result = true; - // Already tested + // Case tested in #0,1,2,3 if constexpr(VectorWidth0 == VectorWidth1) { return result; } // clang-format off - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, (is_layout_same_v), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, (is_layout_same_v), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -598,6 +806,10 @@ namespace rocwmma MaxVectorWidth1, DataLayoutT>; + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw0 = testMmaAccVW(); + constexpr bool is_acc_vw1 = testMmaAccVW(); + bool result = true; // clang-format off @@ -620,6 +832,30 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -655,6 +891,10 @@ namespace rocwmma MaxVectorWidth1, rocwmma::orthogonal_layout_t>; + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw0 = testMmaAccVW(); + constexpr bool is_acc_vw1 = testMmaAccVW(); + bool result = true; // clang-format off @@ -677,6 +917,30 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -697,9 +961,9 @@ namespace rocwmma // Different BlockDim / BlockK constexpr uint32_t VectorWidth = MaxVectorWidth; constexpr uint32_t BlockDim0 = BlockDim; - constexpr uint32_t BlockDim1 = BlockDim == 16u ? 32u : 16u; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; constexpr uint32_t BlockK0 = BlockK; - constexpr uint32_t BlockK1 = BlockK == 16u ? 32u : 16u; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; using Set0 = RegisterLayoutTestingSet mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -805,6 +1093,30 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -856,25 +1168,49 @@ namespace rocwmma } // clang-format off - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, (is_layout_same_v), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, (is_layout_same_v), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; @@ -958,24 +1294,6 @@ namespace rocwmma return result; } - template - ROCWMMA_DEVICE bool layoutTraitsTestAccVoid() - { - // TODO: WaveCount - constexpr uint32_t WaveCount = 1u; - constexpr uint32_t BlockDim = BlockN; - constexpr uint32_t BlockK = BlockM; - using DataLayoutT = void; - constexpr uint32_t MaxVW = rocwmma::detail:: - MaxVWSelector::Result; - - bool result = true; - //result &= dataLayoutTraitsTest(); - result &= testBarrageNonInterleaved(); - - return result; - } - template __global__ void layoutTraitsTest(uint32_t m, uint32_t n, diff --git a/test/unit/layout_traits_test/test/layout_traits_128.cpp b/test/unit/layout_traits_test/test/layout_traits_128.cpp new file mode 100644 index 00000000..013547f7 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_128.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes128; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsTest128 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsTest128, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsTest128, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); diff --git a/test/unit/layout_traits_test/test/layout_traits_16.cpp b/test/unit/layout_traits_test/test/layout_traits_16.cpp index 97181f20..24b1cc04 100644 --- a/test/unit/layout_traits_test/test/layout_traits_16.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_16.cpp @@ -36,7 +36,7 @@ namespace rocwmma struct TestParams : public UnitTestParams { using Base = UnitTestParams; - using Types = std::tuple; //typename Base::TestAllSizeTypes; + using Types = typename Base::TestAllSizeTypes; using BlockSizes = typename Base::TestBlockSizes16; using DataLayouts = typename Base::TestLayoutsAll; using KernelParams = typename CombineLists::Result; diff --git a/test/unit/layout_traits_test/test/layout_traits_256.cpp b/test/unit/layout_traits_test/test/layout_traits_256.cpp new file mode 100644 index 00000000..6fe6d716 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_256.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes256; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsTest256 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsTest256, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsTest256, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); diff --git a/test/unit/layout_traits_test/test/layout_traits_32.cpp b/test/unit/layout_traits_test/test/layout_traits_32.cpp new file mode 100644 index 00000000..eba1359a --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_32.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes32; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsTest32 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsTest32, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsTest32, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); diff --git a/test/unit/layout_traits_test/test/layout_traits_64.cpp b/test/unit/layout_traits_test/test/layout_traits_64.cpp new file mode 100644 index 00000000..662a97db --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_64.cpp @@ -0,0 +1,92 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes64; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsTest64 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsTest64, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsTest64, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); From d88e35344718838622bbeb56a9b46a3195fec2ad Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 27 Nov 2024 04:21:23 +0000 Subject: [PATCH 20/36] Start implementing interleaved layout traits tests --- .../detail/layout_traits_int.hpp | 165 +++ .../device/layout_traits_int.hpp | 1313 +++++++++++++++++ .../test/layout_traits_int_16.cpp | 95 ++ 3 files changed, 1573 insertions(+) create mode 100644 test/unit/layout_traits_test/detail/layout_traits_int.hpp create mode 100644 test/unit/layout_traits_test/device/layout_traits_int.hpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_int_16.cpp diff --git a/test/unit/layout_traits_test/detail/layout_traits_int.hpp b/test/unit/layout_traits_test/detail/layout_traits_int.hpp new file mode 100644 index 00000000..4de13641 --- /dev/null +++ b/test/unit/layout_traits_test/detail/layout_traits_int.hpp @@ -0,0 +1,165 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_LAYOUT_TRAITS_INT_TEST_DETAIL_HPP +#define ROCWMMA_LAYOUT_TRAITS_INT_TEST_DETAIL_HPP + +#include "device/layout_traits_int.hpp" +#include "helper_macros.hpp" +#include "unit_kernel_base.hpp" + +namespace rocwmma +{ + + // Wrapper into the actual device function + template + struct LayoutTraitsIntKernel final : public UnitKernelBase + { + private: + using Base = UnitKernelBase; + + template + using TestGuard = FragSize_guard; + + public: + LayoutTraitsIntKernel() = default; + ~LayoutTraitsIntKernel() final = default; + + void setupImpl(typename Base::DataStorage::ProblemSize const& probsize) final + { + // Need at least 1 element for the result + auto& dataInstance = Base::DataStorage::instance(); + dataInstance->resizeStorage(probsize); + + dataInstance->hostOut().get()[0] = static_cast(ERROR_VALUE); + dataInstance->copyData(dataInstance->deviceOut(), dataInstance->hostOut(), 1); + + // Pass in warpSize from host to validate + Base::mParam1 = static_cast(Base::DeviceInfo::instance()->warpSize()); + } + + void validateResultsImpl() final + { + auto& dataInstance = Base::DataStorage::instance(); + + // Cache current kernel result from device + dataInstance->copyData(dataInstance->hostOut(), dataInstance->deviceOut(), 1); + + // Check the single output result + Base::mValidationResult = (dataInstance->hostOut().get()[0] == DataT(SUCCESS_VALUE)); + } + + bool checkQuirks() const final + { + auto waveSize = Base::DeviceInfo::instance()->warpSize(); + auto deviceArch = Base::DeviceInfo::instance()->getGcnArch(); + + // The test guard for this class requires 2 values at runtime. + auto dispatchGuard = [waveSize, deviceArch]() { + bool dispatchResult = false; + +#define CASE_IMPL_ASSIGN2(WAVE_SIZE, ARCH_ID) \ + dispatchResult = TestGuard::enable(); + +#define SWITCH_BODY_WAVE_SIZE(ARCH_ID) \ + ROCWMMA_SWITCH_BODY2_ARG2( \ + waveSize, CASE_IMPL_ASSIGN2, HipDevice::Wave32, HipDevice::Wave64, ARCH_ID) + +#define DISPATCH_GUARD_BODY \ + ROCWMMA_SWITCH_BODY10_ARG1(deviceArch, \ + SWITCH_BODY_WAVE_SIZE, \ + HipDevice::GFX908, \ + HipDevice::GFX90A, \ + HipDevice::GFX940, \ + HipDevice::GFX941, \ + HipDevice::GFX942, \ + HipDevice::GFX1100, \ + HipDevice::GFX1101, \ + HipDevice::GFX1102, \ + HipDevice::GFX1200, \ + HipDevice::GFX1201) + + DISPATCH_GUARD_BODY + +#undef CASE_IMPL_ASSIGN2 +#undef SWITCH_BODY_WAVE_SIZE +#undef DISPATCH_GUARD_BODY + + return dispatchResult; + }; + + return Base::checkQuirks() && dispatchGuard(); + } + + typename Base::KernelFunc kernelImpl() const final + { + return typename Base::KernelFunc( + layoutTraitsIntTest); + } + }; + + // This is the GeneratorImpl class + struct LayoutTraitsIntGenerator + { + // Indices to test parameters + enum : uint32_t + { + BlockM = 0, + BlockN = 1, + DataT = 2, + DataLayoutT = 3, + MmaDim = 4, + SplitK = 5, + }; + + using ResultT = std::shared_ptr; + + template + static ResultT generate(std::tuple testParams) + { + // Map GTest params to Kernel params + using TestParamsT = std::tuple; + using KernelT = LayoutTraitsIntKernel< + std::tuple_element_t::value, // BlockM + std::tuple_element_t::value, // BlockN + std::tuple_element_t, // DataT + std::tuple_element_t, // DataLayout + std::tuple_element_t::value, // MmaDim + std::tuple_element_t::value // SplitK + >; + + return std::make_shared(); + } + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp new file mode 100644 index 00000000..1dc0b5be --- /dev/null +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -0,0 +1,1313 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_DEVICE_LAYOUT_TRAITS_INT_TEST_HPP +#define ROCWMMA_DEVICE_LAYOUT_TRAITS_INT_TEST_HPP + +#include + +#include "unit_test_traits.hpp" + +static constexpr uint32_t ERROR_VALUE = 7; +static constexpr uint32_t SUCCESS_VALUE = 0; + +namespace rocwmma +{ + template + ROCWMMA_HOST bool + testLayoutPair(const char* file, const char* line, std::ostream& stream = std::cout) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if constexpr(DebugOnFail) + { + stream << "File: " << file << " L:" << line << std::endl; + stream << "" << std::endl; + stream << "Lhs: " << LayoutLhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "Rhs: " << LayoutRhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame + << std::endl; + stream << "is_layout_orthogonal: " << is_layout_orthogonal_result + << " Expected: " << ExpectOrthogonal << std::endl; + stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; + stream << "" << std::endl; + } + + return compare_result; + } + + ROCWMMA_DEVICE inline bool isFirstThread() + { + return (threadIdx.x == 0) && (threadIdx.y == 0) && (threadIdx.z == 0) && (blockIdx.x == 0) + && (blockIdx.y == 0) && (blockIdx.z == 0); + } + + template + ROCWMMA_DEVICE bool testLayoutPair(const char* file, uint32_t line) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if(!compare_result && DebugOnFail && isFirstThread()) + { + printf("File: %s L:%d\n", file, line); + printf("\n"); + printf("is_layout_same: %d (Expected: %d)\n", is_layout_same_result, ExpectSame); + printf("is_layout_orthogonal: %d (Expected: %d)\n", + is_layout_orthogonal_result, + ExpectOrthogonal); + printf("%s\n", (compare_result ? "PASS" : "FAIL")); + printf("\n"); + } + + return compare_result; + } + +#define ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( \ + LayoutLhs, LayoutRhs, ExpectSame, ExpectOrthogonal, DebugOnFail) \ + testLayoutPair(__FILE__, \ + __LINE__); + + template + struct RegisterLayoutIntTestingSet + { + using ColInline = RegisterLayout::Storage< + MatrixLayout::ColInlineInt, + DataLayoutT>; + using ColOrtho = RegisterLayout::Storage< + MatrixLayout::ColOrthoInt, + DataLayoutT>; + using RowInline = RegisterLayout::Storage< + MatrixLayout::RowInlineInt, + DataLayoutT>; + using RowOrtho = RegisterLayout::Storage< + MatrixLayout::RowOrthoInt, + DataLayoutT>; + + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; + }; + + template + using MatrixLayout_t = typename layout_traits::MatrixLayout; + + template + ROCWMMA_DEVICE bool matrixLayoutTraitsTestInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Testing MatrixLayout properties + // MatrixLayouts are invariant to vector width + using Set + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Matrix <-> Matrix layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE constexpr bool testRowMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testColMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testMmaDim() + { + return (MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); + } + + template + ROCWMMA_DEVICE constexpr uint32_t dimPerThread() + { + return BlockDim / MmaDim; + } + + template + ROCWMMA_DEVICE constexpr uint32_t kPerThread() + { + return BlockK * MmaDim / (WaveSize * SplitK); + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + // Checks identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = (dpt == 1u) || (kpt == 1u); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + if constexpr(!is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved1() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = MaxVW + // // datalayout = orthogonal + // constexpr uint32_t VectorWidth = MaxVectorWidth; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet>; + + // constexpr bool is_row_mjr = testRowMajor(); + // constexpr bool is_col_mjr = testColMajor(); + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw = testMmaAccVW(); + + // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + // bool result = true; + + // // Case is tested in #3 + // if constexpr(VectorWidth == 1u) + // { + // return result; + // } + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved2() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = 1u + // // datalayout = same + // constexpr uint32_t VectorWidth = 1u; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw = testMmaAccVW(); + + // bool result = true; + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved3() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = 1u + // // datalayout = orthogonal + // constexpr uint32_t VectorWidth = 1u; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet>; + + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw = testMmaAccVW(); + + // bool result = true; + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved4() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW0 = 1u + // // VW1 = MaxVW + // // datalayout = same + // constexpr uint32_t VectorWidth0 = 1u; + // constexpr uint32_t VectorWidth1 = MaxVectorWidth; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // constexpr bool is_row_mjr = testRowMajor(); + // constexpr bool is_col_mjr = testColMajor(); + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw = testMmaAccVW(); + + // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + // bool result = true; + + // // Case tested in #0,1,2,3 + // if constexpr(VectorWidth0 == VectorWidth1) + // { + // return result; + // } + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved5() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW0 = 1u + // // VW1 = MaxVW + // // datalayout = orthogonal + // constexpr uint32_t VectorWidth0 = 1u; + // constexpr uint32_t VectorWidth1 = MaxVectorWidth; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet>; + + // constexpr bool is_row_mjr = testRowMajor(); + // constexpr bool is_col_mjr = testColMajor(); + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw = testMmaAccVW(); + + // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + // bool result = true; + + // // Case tested in #0,1,2,3 + // if constexpr(VectorWidth0 == VectorWidth1) + // { + // return result; + // } + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved6() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = 1 + // // MaxVW0 = 1 + // // MaxVW1 = MaxVW + // // datalayout = same + // constexpr uint32_t VectorWidth = 1u; + // constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + // constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw0 = testMmaAccVW(); + // constexpr bool is_acc_vw1 = testMmaAccVW(); + + // bool result = true; + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved7() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = 1 + // // MaxVW0 = 1 + // // MaxVW1 = MaxVW + // // datalayout = orthogonal + // constexpr uint32_t VectorWidth = 1u; + // constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + // constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet>; + + // constexpr bool is_mma_dim = testMmaDim(); + // constexpr bool is_acc_vw0 = testMmaAccVW(); + // constexpr bool is_acc_vw1 = testMmaAccVW(); + + // bool result = true; + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved8() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = MaxVW + // // datalayout = same + // // Different BlockDim / BlockK + // constexpr uint32_t VectorWidth = MaxVectorWidth; + // constexpr uint32_t BlockDim0 = BlockDim; + // constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + // constexpr uint32_t BlockK0 = BlockK; + // constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // bool result = true; + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved9() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = MaxVW + // // datalayout = same + // // Different size DataT + // constexpr uint32_t VectorWidth = MaxVectorWidth; + // using DataT0 = DataT; + // using DataT1 = conditional_t< + // sizeof(DataT) == 1u, + // int16_t, + // conditional_t>>>; + + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // bool result = true; + + // // Already checked same types + // if constexpr(is_same_v) + // { + // return result; + // } + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + // template + // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved10() + // { + // constexpr bool debug_on_fail = true; + + // // Non-interleaved + // // VW = MaxVW + // // datalayout = same + // // Same size DataT + // constexpr uint32_t VectorWidth = MaxVectorWidth; + // using DataT0 = DataT; + // using DataT1 = conditional_t< + // sizeof(DataT) == 1u, + // int8_t, + // conditional_t>>>; + + // using Set0 = RegisterLayoutTestingSet; + // using Set1 = RegisterLayoutTestingSet; + + // bool result = true; + + // // Already tested same type + // if constexpr(is_same_v) + // { + // return result; + // } + + // // clang-format off + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // // Storage <-> mma layouts + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // // clang-format on + + // return result; + // } + + template + ROCWMMA_DEVICE bool testBarrageInterleaved() + { + bool result = true; + + // clang-format off + result &= matrixLayoutTraitsTestInterleaved0(); + result &= registerLayoutTraitsTestInterleaved0(); + // result &= registerLayoutTraitsTestInterleaved1(); + // result &= registerLayoutTraitsTestInterleaved2(); + // result &= registerLayoutTraitsTestInterleaved3(); + // result &= registerLayoutTraitsTestInterleaved4(); + // result &= registerLayoutTraitsTestInterleaved5(); + // result &= registerLayoutTraitsTestInterleaved6(); + // result &= registerLayoutTraitsTestInterleaved7(); + // result &= registerLayoutTraitsTestInterleaved8(); + // result &= registerLayoutTraitsTestInterleaved9(); + // result &= registerLayoutTraitsTestInterleaved10(); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestA() + { + constexpr uint32_t BlockDim = BlockM; + constexpr uint32_t BlockK = BlockN; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestB() + { + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestAcc() + { + // TODO: WaveCount + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + __global__ void layoutTraitsIntTest(uint32_t m, + uint32_t n, + DataT const* in, + DataT* out, + uint32_t ld, + DataT param1, + DataT param2) + { + __shared__ int32_t result; + result = 0; + synchronize_workgroup(); + + bool success = true; + + success &= layoutTraitsTestA(); + success &= layoutTraitsTestB(); + success &= layoutTraitsTestAcc(); + + // Reduce error count + atomicAdd(&result, (int32_t)success); + + // Wait for all threads + synchronize_workgroup(); + + // Just need one thread to update output + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 + && blockIdx.y == 0 && blockIdx.z == 0) + { + out[0] = static_cast(result == 0 ? 7 : 0); + } + } + +} // namespace rocwmma + +#endif // ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP diff --git a/test/unit/layout_traits_test/test/layout_traits_int_16.cpp b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp new file mode 100644 index 00000000..5c5b1f12 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "kernel_generator.hpp" +#include "unit_test.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = std::tuple; //typename Base::TestAllSizeTypes; + using MmaDims = std::tuple>; //std::tuple, I<32>>; + using SplitKs = std::tuple>; //std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes16; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +// Test suite for unique parameterization +class LayoutTraitsIntTest16 : public rocwmma::UnitTest +{ +}; + +TEST_P(LayoutTraitsIntTest16, RunKernel) +{ + this->RunKernel(); +} + +INSTANTIATE_TEST_SUITE_P( + KernelTests, + LayoutTraitsIntTest16, + ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), + ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), + ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), + ::testing::ValuesIn(rocwmma::TestParams::param1s()), + ::testing::ValuesIn(rocwmma::TestParams::param2s()))); From 31e2f5acd1406a61fc772d8136bc887bf6a865c3 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 28 Nov 2024 18:58:49 +0000 Subject: [PATCH 21/36] Add interleaved and emulation tests --- test/unit/layout_traits_test/CMakeLists.txt | 13 + .../device/layout_traits.hpp | 5 +- .../device/layout_traits_int.hpp | 1970 +++++++++-------- .../test/common_includes.hpp | 36 + .../emulation/extendedtest_layout_traits.cpp | 77 + .../extendedtest_layout_traits_int.cpp | 81 + .../regressiontest_layout_traits.cpp | 76 + .../regressiontest_layout_traits_int.cpp | 80 + .../emulation/smoketest_layout_traits.cpp | 74 + .../emulation/smoketest_layout_traits_int.cpp | 78 + .../test/layout_traits_128.cpp | 22 +- .../test/layout_traits_16.cpp | 22 +- .../test/layout_traits_256.cpp | 22 +- .../test/layout_traits_32.cpp | 22 +- .../test/layout_traits_64.cpp | 22 +- .../test/layout_traits_int_128.cpp | 78 + .../test/layout_traits_int_16.cpp | 29 +- .../test/layout_traits_int_256.cpp | 78 + .../test/layout_traits_int_32.cpp | 78 + .../test/layout_traits_int_64.cpp | 78 + 20 files changed, 1899 insertions(+), 1042 deletions(-) create mode 100644 test/unit/layout_traits_test/test/common_includes.hpp create mode 100644 test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits.cpp create mode 100644 test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp create mode 100644 test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp create mode 100644 test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp create mode 100644 test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp create mode 100644 test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_int_128.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_int_256.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_int_32.cpp create mode 100644 test/unit/layout_traits_test/test/layout_traits_int_64.cpp diff --git a/test/unit/layout_traits_test/CMakeLists.txt b/test/unit/layout_traits_test/CMakeLists.txt index 7c33c88d..de94d7e7 100644 --- a/test/unit/layout_traits_test/CMakeLists.txt +++ b/test/unit/layout_traits_test/CMakeLists.txt @@ -33,6 +33,19 @@ set(LayoutTraitsTestSources ${UnitCommonSources} ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_64.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_128.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_256.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_128.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_256.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/smoketest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/smoketest_layout_traits_int.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/regressiontest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/regressiontest_layout_traits_int.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/extendedtest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/extendedtest_layout_traits_int.cpp ) add_rocwmma_unit_test(layout_traits_test ${LayoutTraitsTestSources}) diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp index c76ec791..35bdba98 100644 --- a/test/unit/layout_traits_test/device/layout_traits.hpp +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -42,7 +42,8 @@ namespace rocwmma bool ExpectSame, bool ExpectOrthogonal, bool DebugOnFail> - ROCWMMA_HOST bool testLayoutPair(const char* file, const char* line, std::ostream& stream) + ROCWMMA_HOST bool + testLayoutPair(const char* file, const char* line, std::ostream& stream = std::cout) { constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; constexpr bool is_layout_orthogonal_result @@ -50,7 +51,7 @@ namespace rocwmma constexpr bool compare_result = ((is_layout_same_result == ExpectSame) && (is_layout_orthogonal_result == ExpectOrthogonal)); - if(DebugOnFail) + if constexpr(DebugOnFail) { stream << "File: " << file << " L:" << line << std::endl; stream << "" << std::endl; diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp index 1dc0b5be..b4323457 100644 --- a/test/unit/layout_traits_test/device/layout_traits_int.hpp +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -129,8 +129,8 @@ namespace rocwmma MatrixLayout::RowOrthoInt, DataLayoutT>; - using MmaInput = RegisterLayout::MmaInput; - using MmaAcc = RegisterLayout::MmaAcc; + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; }; template @@ -239,7 +239,106 @@ namespace rocwmma constexpr bool is_row_mjr = testRowMajor(); constexpr bool is_col_mjr = testColMajor(); - constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_row_mjr_dpt_1 = is_row_mjr && is_dpt_eq_1; + constexpr bool is_row_mjr_kpt_1 = is_row_mjr && is_kpt_eq_1; + constexpr bool is_col_mjr_dpt_1 = is_col_mjr && is_dpt_eq_1; + constexpr bool is_col_mjr_kpt_1 = is_col_mjr && is_kpt_eq_1; + + bool result = true; + + if constexpr(!is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1 ), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)) , false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_dpt_eq_1, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_dpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved1() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + // Checks non-identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); constexpr uint32_t dpt = dimPerThread(); constexpr uint32_t kpt = kPerThread(); @@ -252,7 +351,7 @@ namespace rocwmma bool result = true; - if constexpr(!is_id_quirk) + if constexpr(is_id_quirk) { return result; } @@ -260,45 +359,45 @@ namespace rocwmma // Storage <-> storage layout // clang-format off result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -307,888 +406,108 @@ namespace rocwmma return result; } - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved1() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = MaxVW - // // datalayout = orthogonal - // constexpr uint32_t VectorWidth = MaxVectorWidth; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet>; - - // constexpr bool is_row_mjr = testRowMajor(); - // constexpr bool is_col_mjr = testColMajor(); - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw = testMmaAccVW(); - - // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; - // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; - // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; - // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; - - // bool result = true; - - // // Case is tested in #3 - // if constexpr(VectorWidth == 1u) - // { - // return result; - // } - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved2() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = 1u - // // datalayout = same - // constexpr uint32_t VectorWidth = 1u; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw = testMmaAccVW(); - - // bool result = true; - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved3() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = 1u - // // datalayout = orthogonal - // constexpr uint32_t VectorWidth = 1u; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet>; - - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw = testMmaAccVW(); - - // bool result = true; - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved4() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW0 = 1u - // // VW1 = MaxVW - // // datalayout = same - // constexpr uint32_t VectorWidth0 = 1u; - // constexpr uint32_t VectorWidth1 = MaxVectorWidth; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // constexpr bool is_row_mjr = testRowMajor(); - // constexpr bool is_col_mjr = testColMajor(); - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw = testMmaAccVW(); - - // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; - // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; - // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; - // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; - - // bool result = true; - - // // Case tested in #0,1,2,3 - // if constexpr(VectorWidth0 == VectorWidth1) - // { - // return result; - // } - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved5() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW0 = 1u - // // VW1 = MaxVW - // // datalayout = orthogonal - // constexpr uint32_t VectorWidth0 = 1u; - // constexpr uint32_t VectorWidth1 = MaxVectorWidth; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet>; - - // constexpr bool is_row_mjr = testRowMajor(); - // constexpr bool is_col_mjr = testColMajor(); - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw = testMmaAccVW(); - - // constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; - // constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; - // constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; - // constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; - - // bool result = true; - - // // Case tested in #0,1,2,3 - // if constexpr(VectorWidth0 == VectorWidth1) - // { - // return result; - // } - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved6() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = 1 - // // MaxVW0 = 1 - // // MaxVW1 = MaxVW - // // datalayout = same - // constexpr uint32_t VectorWidth = 1u; - // constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; - // constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw0 = testMmaAccVW(); - // constexpr bool is_acc_vw1 = testMmaAccVW(); - - // bool result = true; - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved7() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = 1 - // // MaxVW0 = 1 - // // MaxVW1 = MaxVW - // // datalayout = orthogonal - // constexpr uint32_t VectorWidth = 1u; - // constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; - // constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet>; - - // constexpr bool is_mma_dim = testMmaDim(); - // constexpr bool is_acc_vw0 = testMmaAccVW(); - // constexpr bool is_acc_vw1 = testMmaAccVW(); - - // bool result = true; - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved8() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = MaxVW - // // datalayout = same - // // Different BlockDim / BlockK - // constexpr uint32_t VectorWidth = MaxVectorWidth; - // constexpr uint32_t BlockDim0 = BlockDim; - // constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; - // constexpr uint32_t BlockK0 = BlockK; - // constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // bool result = true; - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved9() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = MaxVW - // // datalayout = same - // // Different size DataT - // constexpr uint32_t VectorWidth = MaxVectorWidth; - // using DataT0 = DataT; - // using DataT1 = conditional_t< - // sizeof(DataT) == 1u, - // int16_t, - // conditional_t>>>; - - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // bool result = true; - - // // Already checked same types - // if constexpr(is_same_v) - // { - // return result; - // } - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } - - // template - // ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved10() - // { - // constexpr bool debug_on_fail = true; - - // // Non-interleaved - // // VW = MaxVW - // // datalayout = same - // // Same size DataT - // constexpr uint32_t VectorWidth = MaxVectorWidth; - // using DataT0 = DataT; - // using DataT1 = conditional_t< - // sizeof(DataT) == 1u, - // int8_t, - // conditional_t>>>; - - // using Set0 = RegisterLayoutTestingSet; - // using Set1 = RegisterLayoutTestingSet; - - // bool result = true; - - // // Already tested same type - // if constexpr(is_same_v) - // { - // return result; - // } - - // // clang-format off - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); - - // // Storage <-> mma layouts - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); - - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); - // result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); - // // clang-format on - - // return result; - // } + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved2() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = orthogonal + // MmaDim = same + // SplitK = same + // Checks identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet, + MmaDim, + SplitK>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_row_mjr_dpt_1 = is_row_mjr && is_dpt_eq_1; + constexpr bool is_row_mjr_kpt_1 = is_row_mjr && is_kpt_eq_1; + constexpr bool is_col_mjr_dpt_1 = is_col_mjr && is_dpt_eq_1; + constexpr bool is_col_mjr_kpt_1 = is_col_mjr && is_kpt_eq_1; + + bool result = true; + + if constexpr(!is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true , false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_dpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); // Can be invalid in same way + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_dpt_eq_1, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } template - ROCWMMA_DEVICE bool testBarrageInterleaved() + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved3() { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = orthogonal + // MmaDim = same + // SplitK = same + // Checks non-identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet, + MmaDim, + SplitK>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + bool result = true; + if constexpr(is_id_quirk) + { + return result; + } + + // Storage <-> storage layout // clang-format off - result &= matrixLayoutTraitsTestInterleaved0(); - result &= registerLayoutTraitsTestInterleaved0(); - // result &= registerLayoutTraitsTestInterleaved1(); - // result &= registerLayoutTraitsTestInterleaved2(); - // result &= registerLayoutTraitsTestInterleaved3(); - // result &= registerLayoutTraitsTestInterleaved4(); - // result &= registerLayoutTraitsTestInterleaved5(); - // result &= registerLayoutTraitsTestInterleaved6(); - // result &= registerLayoutTraitsTestInterleaved7(); - // result &= registerLayoutTraitsTestInterleaved8(); - // result &= registerLayoutTraitsTestInterleaved9(); - // result &= registerLayoutTraitsTestInterleaved10(); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); // Can be invalid in same way + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); // clang-format on return result; } + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved4() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = different + // BlockK = different + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + bool result = true; + + // Current test case deals with quirks validation + if constexpr((is_id_quirk0 != is_id_quirk1) || !is_id_quirk0) + { + return result; + } + + // Ensure MmaDim layout constraints are met + if constexpr(BlockDim0 >= MmaDim && BlockDim1 >= MmaDim) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // Same MmaDim + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved5() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = different + // BlockK = different + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Other test case deals with quirks validation + if constexpr((is_id_quirk0 != is_id_quirk1) || is_id_quirk0) + { + return result; + } + + // Ensure MmaDim layout constraints are met + if constexpr(BlockDim0 >= MmaDim && BlockDim1 >= MmaDim) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // Same MmaDim + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved6() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = different, same size + // DataLayoutT = same + // MmaDim = same + // SplitK = same + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int8_t, + conditional_t>>>; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved7() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = different, different size + // DataLayoutT = same + // MmaDim = same + // SplitK = same + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int16_t, + conditional_t>>>; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved8() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = different, valid + // SplitK = same + constexpr uint32_t MmaDim0 = MmaDim; + constexpr uint32_t MmaDim1 = MmaDim == 16 ? 32u : 16u; + + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + bool result = true; + + // Ensure MmaDim layout constraints are met + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + using Set0 = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet; + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved9() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = different + constexpr uint32_t SplitK0 = SplitK; + constexpr uint32_t SplitK1 = SplitK == 4u ? 2u : 4u; + + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Current test deals with quirk case validation + if constexpr((is_id_quirk0 != is_id_quirk1) || !is_id_quirk0) + { + return result; + } + + // Ensure layout requirements are satisfied + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved10() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = different + constexpr uint32_t SplitK0 = SplitK; + constexpr uint32_t SplitK1 = SplitK == 4u ? 2u : 4u; + + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Other test handles quirk case validation + if constexpr((is_id_quirk0 != is_id_quirk1) || is_id_quirk0) + { + return result; + } + + // Ensure layout requirements are satisfied + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool testBarrageInterleaved() + { + bool result = true; + + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // Must satisfy layout requirement + if constexpr(dpt > 0u && kpt > 0u) + { + // clang-format off + result &= matrixLayoutTraitsTestInterleaved0(); + result &= registerLayoutTraitsTestInterleaved0(); + result &= registerLayoutTraitsTestInterleaved1(); + result &= registerLayoutTraitsTestInterleaved2(); + result &= registerLayoutTraitsTestInterleaved3(); + result &= registerLayoutTraitsTestInterleaved4(); + result &= registerLayoutTraitsTestInterleaved5(); + result &= registerLayoutTraitsTestInterleaved6(); + result &= registerLayoutTraitsTestInterleaved7(); + result &= registerLayoutTraitsTestInterleaved8(); + result &= registerLayoutTraitsTestInterleaved9(); + result &= registerLayoutTraitsTestInterleaved10(); + // clang-format on + } + + return result; + } + template + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>, + std::tuple, I<128u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationExtendedLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp new file mode 100644 index 00000000..165a8b08 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32u>>; + using SplitKs = std::tuple, I<2u>, I<4u>>; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>, + std::tuple, I<128u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationExtendedLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp new file mode 100644 index 00000000..dbd84026 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp @@ -0,0 +1,76 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationRegressionLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp new file mode 100644 index 00000000..71d30a9e --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32u>>; + using SplitKs = std::tuple, I<2u>, I<4u>>; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationRegressionLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp new file mode 100644 index 00000000..603dfdb8 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, std::tuple, I<32u>>>; + using DataLayouts = std::tuple; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationSmokeLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp new file mode 100644 index 00000000..602a2a3d --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple>; + using SplitKs = std::tuple>; + using BlockSizes = std::tuple, I<16u>>, std::tuple, I<32u>>>; + using DataLayouts = std::tuple; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationSmokeLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_128.cpp b/test/unit/layout_traits_test/test/layout_traits_128.cpp index 013547f7..4f0f580b 100644 --- a/test/unit/layout_traits_test/test/layout_traits_128.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_128.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -72,21 +71,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsTest128 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsTest128, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsTest128, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest128, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_16.cpp b/test/unit/layout_traits_test/test/layout_traits_16.cpp index 24b1cc04..d97bca1a 100644 --- a/test/unit/layout_traits_test/test/layout_traits_16.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_16.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -72,21 +71,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsTest16 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsTest16, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsTest16, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest16, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_256.cpp b/test/unit/layout_traits_test/test/layout_traits_256.cpp index 6fe6d716..1917a999 100644 --- a/test/unit/layout_traits_test/test/layout_traits_256.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_256.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -72,21 +71,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsTest256 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsTest256, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsTest256, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest256, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_32.cpp b/test/unit/layout_traits_test/test/layout_traits_32.cpp index eba1359a..58c4b016 100644 --- a/test/unit/layout_traits_test/test/layout_traits_32.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_32.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -72,21 +71,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsTest32 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsTest32, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsTest32, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest32, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_64.cpp b/test/unit/layout_traits_test/test/layout_traits_64.cpp index 662a97db..e961e12d 100644 --- a/test/unit/layout_traits_test/test/layout_traits_64.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_64.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -72,21 +71,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsTest64 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsTest64, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsTest64, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest64, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_128.cpp b/test/unit/layout_traits_test/test/layout_traits_int_128.cpp new file mode 100644 index 00000000..459ff6a8 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_128.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes128; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest128, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_16.cpp b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp index 5c5b1f12..5e708a5c 100644 --- a/test/unit/layout_traits_test/test/layout_traits_int_16.cpp +++ b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp @@ -26,9 +26,8 @@ #include +#include "common_includes.hpp" #include "detail/layout_traits_int.hpp" -#include "kernel_generator.hpp" -#include "unit_test.hpp" namespace rocwmma { @@ -36,11 +35,12 @@ namespace rocwmma struct TestParams : public UnitTestParams { using Base = UnitTestParams; - using Types = std::tuple; //typename Base::TestAllSizeTypes; - using MmaDims = std::tuple>; //std::tuple, I<32>>; - using SplitKs = std::tuple>; //std::tuple, I<2>, I<4>>; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple>; + using SplitKs = std::tuple, I<2>, I<4>>; using BlockSizes = typename Base::TestBlockSizes16; using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; @@ -75,21 +75,4 @@ namespace rocwmma } // namespace rocwmma -// Test suite for unique parameterization -class LayoutTraitsIntTest16 : public rocwmma::UnitTest -{ -}; - -TEST_P(LayoutTraitsIntTest16, RunKernel) -{ - this->RunKernel(); -} - -INSTANTIATE_TEST_SUITE_P( - KernelTests, - LayoutTraitsIntTest16, - ::testing::Combine(::testing::ValuesIn(rocwmma::TestParams::kernels()), - ::testing::ValuesIn(rocwmma::TestParams::threadBlocks()), - ::testing::ValuesIn(rocwmma::TestParams::problemSizes()), - ::testing::ValuesIn(rocwmma::TestParams::param1s()), - ::testing::ValuesIn(rocwmma::TestParams::param2s()))); +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest16, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_256.cpp b/test/unit/layout_traits_test/test/layout_traits_int_256.cpp new file mode 100644 index 00000000..a50a44bf --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_256.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes256; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest256, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_32.cpp b/test/unit/layout_traits_test/test/layout_traits_int_32.cpp new file mode 100644 index 00000000..5f222865 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_32.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes32; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest32, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_64.cpp b/test/unit/layout_traits_test/test/layout_traits_int_64.cpp new file mode 100644 index 00000000..3de30ed2 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_64.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes64; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest64, TestParams); From ef05816fc3accfcfd4df8bf71e1b4b9da342b824 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 28 Nov 2024 20:46:47 +0000 Subject: [PATCH 22/36] Fix build of layout unit tests --- test/unit/layout_test/device/col_layout.hpp | 22 ++++++++++++++----- test/unit/layout_test/device/colnt_layout.hpp | 21 ++++++++++++------ test/unit/layout_test/device/row_layout.hpp | 8 ++++--- test/unit/layout_test/device/rownt_layout.hpp | 7 +++--- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/test/unit/layout_test/device/col_layout.hpp b/test/unit/layout_test/device/col_layout.hpp index aee9b1c3..a5d8bc9a 100644 --- a/test/unit/layout_test/device/col_layout.hpp +++ b/test/unit/layout_test/device/col_layout.hpp @@ -52,15 +52,25 @@ namespace rocwmma { enum : uint32_t { + BlockHeight = BlockM, + BlockWidth = BlockN, + + BlockDim = BlockM, + KDim = BlockN, + MaxVectorWidth - = detail::MaxVWSelector::Result, + = detail::MaxVWSelector::Result, VectorWidth = std::is_same_v ? MaxVectorWidth : 1 }; - using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile:: - Col::MatrixLayout; - using Mapping = MappingUtil; + using IOTraits = IOTraits; + + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::ColInlineVW, + MatrixLayout::ColOrthoVW>; + + using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); auto iocount = IOTraits::IOCount; @@ -74,7 +84,7 @@ namespace rocwmma for(uint32_t i = 0; i < iocount; ++i) { - for(int j = 0; j < VectorWidth; j++) + for(uint32_t j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) + Mapping::dataOffset(baseOffset, ld) + j; diff --git a/test/unit/layout_test/device/colnt_layout.hpp b/test/unit/layout_test/device/colnt_layout.hpp index e19c1165..84b40b19 100644 --- a/test/unit/layout_test/device/colnt_layout.hpp +++ b/test/unit/layout_test/device/colnt_layout.hpp @@ -52,16 +52,23 @@ namespace rocwmma { enum : uint32_t { + BlockHeight = BlockM, + BlockWidth = BlockN, + + BlockDim = BlockM, + KDim = BlockN, + MaxVectorWidth - = detail::MaxVWSelector::Result, + = detail::MaxVWSelector::Result, VectorWidth = std::is_same_v ? MaxVectorWidth : 1 }; - using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile:: - ColNT:: - MatrixLayout; - using Mapping = MappingUtil; + using IOTraits = IOTraits; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::ColOrthoVW, + MatrixLayout::ColOrthoVW>; + using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); auto iocount = IOTraits::IOCount; @@ -75,7 +82,7 @@ namespace rocwmma for(uint32_t i = 0; i < iocount; ++i) { - for(int j = 0; j < VectorWidth; j++) + for(uint32_t j = 0; j < VectorWidth; j++) { auto index = (get(matrixCoord) * ld + get(matrixCoord)) + Mapping::dataOffset(baseOffset, ld) + j; diff --git a/test/unit/layout_test/device/row_layout.hpp b/test/unit/layout_test/device/row_layout.hpp index bd7d2106..ea9a4898 100644 --- a/test/unit/layout_test/device/row_layout.hpp +++ b/test/unit/layout_test/device/row_layout.hpp @@ -60,13 +60,15 @@ namespace rocwmma KDim = BlockM, MaxVectorWidth - = detail::MaxVWSelector::Result, + = detail::MaxVWSelector::Result, VectorWidth = std::is_same_v ? MaxVectorWidth : 1 }; using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile:: - Row::MatrixLayout; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::RowInlineVW, + MatrixLayout::RowOrthoVW>; using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); diff --git a/test/unit/layout_test/device/rownt_layout.hpp b/test/unit/layout_test/device/rownt_layout.hpp index afc2fab2..c19c9b99 100644 --- a/test/unit/layout_test/device/rownt_layout.hpp +++ b/test/unit/layout_test/device/rownt_layout.hpp @@ -67,9 +67,10 @@ namespace rocwmma }; using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile:: - RowNT:: - MatrixLayout; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::RowOrthoVW, + MatrixLayout::RowOrthoVW>; using Mapping = MappingUtil; auto baseOffset = LayoutT::baseOffset(); From b87547189f95be200911d88477dd96e44873be38 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 2 Dec 2024 15:15:18 -0700 Subject: [PATCH 23/36] Fix gfx11 implementation --- .../layout/register_layout_traits_impl.hpp | 128 ++++++++++-------- .../layout/register_layout_transforms.hpp | 15 ++ library/include/rocwmma/internal/wmma.hpp | 28 +--- library/include/rocwmma/rocwmma_impl.hpp | 3 +- 4 files changed, 92 insertions(+), 82 deletions(-) diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 0c60a82f..019338ff 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -55,13 +55,13 @@ namespace rocwmma { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; - template - struct is_register_layout> : public true_type + template + struct is_register_layout> : public true_type { }; @@ -80,8 +80,8 @@ namespace rocwmma { }; - template - struct is_mma_input> : public true_type + template + struct is_mma_input> : public true_type { }; @@ -90,8 +90,8 @@ namespace rocwmma { }; - template - struct is_mma_acc> : public true_type + template + struct is_mma_acc> : public true_type { }; @@ -186,50 +186,68 @@ namespace rocwmma { using traits = register_layout_traits; using rocwmma::RegisterLayout::Format; - if constexpr(traits::is_mma_input) + if constexpr((bool)ROCWMMA_ARCH_GFX11) { - if constexpr(traits::is_interleaved) + if constexpr(traits::is_mma_input) + { + return traits::Format == Format::WMMA_INPUT_GFX11; + } + else if constexpr(traits::is_mma_acc) { - return (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT); + if constexpr(traits::is_interleaved) + { + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) + || (traits::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11); + } + else + { + return (traits::Format == Format::WMMA_ACC_GFX11); + } } else { - return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + return traits::is_storage + && ((traits::Format == Format::SOA) + || (traits::Format == Format::AOS) + || (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT)); } } - else if constexpr(traits::is_mma_acc) + else // Other archs { -#if ROCWMMA_ARCH_GFX11 - if constexpr(traits::is_interleaved) + if constexpr(traits::is_mma_input) { - // Intermediate accumulation format for interleaved layout - return (traits::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) - || (traits::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11); + if constexpr(traits::is_interleaved) + { + return (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT); + } + else + { + return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + } } - else + else if constexpr(traits::is_mma_acc) { - return (traits::Format == WMMA_ACC_GFX11); - } -#else - if constexpr(traits::is_interleaved) - { - // Intermediate accumulation format for interleaved layout - return (traits::Format == Format::ACC_INT_A_MAJOR) - || (traits::Format == Format::ACC_INT_B_MAJOR); + if constexpr(traits::is_interleaved) + { + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::ACC_INT_A_MAJOR) + || (traits::Format == Format::ACC_INT_B_MAJOR); + } + else + { + return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + } } else { - return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); + return traits::is_storage + && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) + || (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT)); } -#endif // ROCWMMA_ARCH_GFX11 - } - else - { - return traits::is_storage - && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) - || (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT)); } } @@ -494,31 +512,25 @@ namespace rocwmma // ACC_INT_A_MAJOR <-> AOS, SOA // ACC_INT_B_MAJOR <-> AOS, SOA // Register layouts must be valid to be orthogonal + // clang-format off using RegisterLayout::Format; constexpr bool TestOpposingFormat - = ((traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) + = ( (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) - || (traits_lhs::Format == Format::SOA_INT - && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::AOS_INT - && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_A_MAJOR - && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_A_MAJOR - && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA_INT - && traits_rhs::Format == Format::ACC_INT_A_MAJOR) - || (traits_lhs::Format == Format::AOS_INT - && traits_rhs::Format == Format::ACC_INT_A_MAJOR) - || (traits_lhs::Format == Format::ACC_INT_B_MAJOR - && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_B_MAJOR - && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA_INT - && traits_rhs::Format == Format::ACC_INT_B_MAJOR) - || (traits_lhs::Format == Format::AOS_INT - && traits_rhs::Format == Format::ACC_INT_B_MAJOR)) + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR)) && (traits_lhs::is_valid && traits_rhs::is_valid); + // clang-format on return TestNotSame && TestCompatibleParams && TestOpposingFormat; } diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 0e7e010d..62233661 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -147,6 +147,21 @@ namespace rocwmma = conditional_t; return interleave<1u, storage_traits::KPerThread>(forward(v)); } + else if constexpr((traits_lhs::Format == Format::SOA || traits_lhs::Format == Format::AOS) + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + // Input is unpacked + using VecTraits = VecTraits>; + using PackUtil = PackUtil; + + // Swap upper / lower 16's and then concatenate them + // to make sure we have each K value in each half. + // GFX11 wmma layout quirk needs the duplication. + auto packed = PackUtil::pack(v); + auto swapped = Swizzle::Swap16::exec(packed); + auto result = PackUtil::unpack(concat(packed, swapped)); + return result; // Return by copy + } else { static_assert(0, "Register layout transform is not implemented"); diff --git a/library/include/rocwmma/internal/wmma.hpp b/library/include/rocwmma/internal/wmma.hpp index 7308b5a1..74525c20 100644 --- a/library/include/rocwmma/internal/wmma.hpp +++ b/library/include/rocwmma/internal/wmma.hpp @@ -129,9 +129,9 @@ namespace rocwmma exec(InputARegsT const& regsA, InputBRegsT const& regsB, InputCRegsT const& regsC) { // Inputs from outside will come in as fully packed - static_assert(VecTraits::size() == IOTraitsA::PackedSize, + static_assert(VecTraits::size() == VecTraitsA::size() * Traits::WmmaCount, "WMMA input size mismatch"); - static_assert(VecTraits::size() == IOTraitsB::PackedSize, + static_assert(VecTraits::size() == VecTraitsA::size() * Traits::WmmaCount, "WMMA input size mismatch"); static_assert(VecTraits::size() == IOTraitsAcc::PackedSize, "WMMA input size mismatch"); @@ -144,32 +144,14 @@ namespace rocwmma auto accum = PackUtil::template pad(PackUtil::unpack(regsC)); // Iterate over packed WMMA inputs - auto const aIt - = makeVectorIterator(regsA).begin(); - auto const bIt - = makeVectorIterator(regsB).begin(); + auto const aIt = makeVectorIterator(regsA).begin(); + auto const bIt = makeVectorIterator(regsB).begin(); // Accumulate over WMMA count #pragma unroll for(uint32_t i = 0; i < Traits::WmmaCount; i++) { -#if ROCWMMA_ARCH_GFX11 - // Swap upper / lower 16 elements - auto swappedA = Swizzle::Swap16::exec(*aIt); - auto swappedB = Swizzle::Swap16::exec(*bIt); - - // Combine duplicated data for mult/accum. - // Evens: non-swapped - // Odds: swapped - accum = WMMA::exec(concat(unpackLo(*aIt, swappedA), unpackHi(*aIt, swappedA)), - concat(unpackLo(*bIt, swappedB), unpackHi(*bIt, swappedB)), - accum); -#else - accum = WMMA::exec(*aIt, *bIt, accum); - -#endif - aIt++; bIt++; } @@ -182,4 +164,4 @@ namespace rocwmma } // namespace rocwmma -#endif // ROCWMMA_WMMA_HPP +#endif // ROCWMMA_WMMA_HPP \ No newline at end of file diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index 552c8871..4968764e 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -368,7 +368,8 @@ namespace rocwmma Mfma, Wmma>; - // mma functions operate on packed vectors + // Operate pre-ops on unpacked vectors + // the pack for mma inputs (*d) = MMA::exec(PackA::pack(PreMmaA::exec(a.mAccess)), PackB::pack(PreMmaB::exec(b.mAccess)), PackAcc::pack(PreMmaAcc::exec(c.mAccess))); From b03528947aa90b3ce5309a2f2427b2eb0534219a Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 2 Dec 2024 15:15:37 -0700 Subject: [PATCH 24/36] Restore perf_hgemm --- samples/perf_hgemm.cpp | 658 +++++++++-------------------------------- 1 file changed, 142 insertions(+), 516 deletions(-) diff --git a/samples/perf_hgemm.cpp b/samples/perf_hgemm.cpp index fd469567..766f68b8 100644 --- a/samples/perf_hgemm.cpp +++ b/samples/perf_hgemm.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -251,11 +251,11 @@ namespace gfx11Params }; } -//#if(ROCWMMA_ARCH_GFX9) +#if(ROCWMMA_ARCH_GFX9) using namespace gfx9Params; -//#else -//using namespace gfx11Params; -//#endif // defined(ROCWMMA_ARCH_GFX9) +#else +using namespace gfx11Params; +#endif // defined(ROCWMMA_ARCH_GFX9) /// /// Types and Data Layouts @@ -358,474 +358,173 @@ ROCWMMA_DEVICE static inline void { // Transpose B and then apply lds data layout store_matrix_coop_sync( - ldsAddr, - applyDataLayout(applyTranspose(grBuffB)), - ldsld, - waveIndexB); + ldsAddr, applyDataLayout(applyTranspose(grBuffB)), ldsld, waveIndexB); } -// Global read (macro tile) -using LRBuffA = fragment; -using LRBuffB = ApplyTranspose_t; -using GRBuffC = fragment; -using AccumBuffInt = fragment; - +// Local A reads for warp tile gemm, non-cooperative ROCWMMA_DEVICE static inline void - localReadA(LRBuffA& fragsA, InputT const* ldsAddrA, uint32_t ldsld) + localReadA(MfmaFragA (&fragsA)[BLOCKS_X], InputT const* ldsAddrA, uint32_t ldsld) { - constexpr uint32_t VW = 4; - - using Profile - = rocwmma::LayoutProfile::ColInt; + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; + // Each A block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); - using Loader = OpaqueLoad; - - // Load then implicit pack - Loader::exec(fragsA.mAccess, ldsAddrA, ldsld); +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + LRFragA tmp; + load_matrix_sync(tmp, ldsAddrA, ldsld); + fragsA[i] = applyDataLayout(tmp); - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63) - // { - // auto reg = 0u; - // auto x0 = fragsA.mAccess.data[0]; - // auto x1 = fragsA.mAccess.data[1]; - // auto x2 = fragsA.mAccess.data[2]; - // auto x3 = fragsA.mAccess.data[3]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } + ldsAddrA += blockStep; + } } // Local B reads for warp tile gemm, non-cooperative ROCWMMA_DEVICE static inline void - localReadB(LRBuffB& fragsB, InputT const* ldsAddrB, uint32_t ldsld) + localReadB(MfmaFragB (&fragsB)[BLOCKS_Y], InputT const* ldsAddrB, uint32_t ldsld) { - // How to choose? Comes from the IOConfig? - constexpr uint32_t VW = 4; - - using Profile - = rocwmma::LayoutProfile::ColInt; + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; - using MatrixLayout = typename Profile::MatrixLayout; - using DataLayout = typename Profile::DataLayout; + // Each B block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); - using Loader = OpaqueLoad; +#pragma unroll + for(int i = 0; i < BLOCKS_Y; i++) + { + LRFragB tmp; + load_matrix_sync(tmp, ldsAddrB, ldsld); - // Load then implicit pack - Loader::exec(reinterpret_cast(fragsB).mAccess, ldsAddrB, ldsld); + // Transform back to MFMA tile + fragsB[i] = applyDataLayout(applyTranspose(tmp)); - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63) - // { - // auto reg = 0u; - // auto x0 = fragsB.mAccess.data[0]; - // auto x1 = fragsB.mAccess.data[1]; - // auto x2 = fragsB.mAccess.data[2]; - // auto x3 = fragsB.mAccess.data[3]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } + ldsAddrB += blockStep; + } } // Global C reads for warp tile gemm, non-cooperative -ROCWMMA_DEVICE static inline void globalReadC(GRBuffC& fragsC, OutputT const* gAddrC, uint32_t ldc) +ROCWMMA_DEVICE static inline void + globalReadC(MfmaFragC (&fragC)[BLOCKS_X][BLOCKS_Y], OutputT const* gAddrC, uint32_t ldc) { - // How to choose? Comes from the IOConfig? - constexpr uint32_t VW = 4; - - using Profile - = rocwmma::LayoutProfile::RowInt; - - using MatrixLayout = typename Profile::MatrixLayout; - using DataLayout = typename Profile::DataLayout; - - using Loader = OpaqueLoad; - - // Load then implicit pack - GRBuffC tmp; - Loader::exec(tmp.mAccess, gAddrC, ldc); - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // auto reg = 0u; - // auto x0 = tmp.mAccess.data[0]; - // auto x1 = tmp.mAccess.data[1]; - // auto x2 = tmp.mAccess.data[2]; - // auto x3 = tmp.mAccess.data[3]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - //MatrixLayout::debug(); - { - // Post load to accum format + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Iterative offsets for each C block in the wave tile + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldc); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldc); #pragma unroll - for(int i = 0; i < 4u; i++) + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) { - fragsC.mAccess.data[0 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 0]; - fragsC.mAccess.data[1 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 1]; - fragsC.mAccess.data[2 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 2]; - fragsC.mAccess.data[3 * 16 + 0 + i] = tmp.mAccess.data[i * 16 + 0 + 3]; - - fragsC.mAccess.data[0 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 0]; - fragsC.mAccess.data[1 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 1]; - fragsC.mAccess.data[2 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 2]; - fragsC.mAccess.data[3 * 16 + 4 + i] = tmp.mAccess.data[i * 16 + 4 + 3]; - - fragsC.mAccess.data[0 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 0]; - fragsC.mAccess.data[1 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 1]; - fragsC.mAccess.data[2 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 2]; - fragsC.mAccess.data[3 * 16 + 8 + i] = tmp.mAccess.data[i * 16 + 8 + 3]; - - fragsC.mAccess.data[0 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 0]; - fragsC.mAccess.data[1 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 1]; - fragsC.mAccess.data[2 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 2]; - fragsC.mAccess.data[3 * 16 + 12 + i] = tmp.mAccess.data[i * 16 + 12 + 3]; + load_matrix_sync(fragC[i][j], gAddrC + offsetY, ldc); + offsetY += blockStepY; } + gAddrC += blockStepX; } - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // auto reg = 0u; - // auto x0 = fragsC.mAccess.data[12]; - // auto x1 = fragsC.mAccess.data[13]; - // auto x2 = fragsC.mAccess.data[14]; - // auto x3 = fragsC.mAccess.data[15]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } } // Global D reads for warp tile gemm, non-cooperative -ROCWMMA_DEVICE static inline void globalWriteD(OutputT* gAddrD, GRBuffC const& fragsD, uint32_t ldd) +ROCWMMA_DEVICE static inline void + globalWriteD(OutputT* gAddrD, MfmaFragD const (&fragsD)[BLOCKS_X][BLOCKS_Y], uint32_t ldd) { - // How to choose? Comes from the IOConfig? - constexpr uint32_t VW = 4; - - using Profile - = rocwmma::LayoutProfile::RowInt; - - using MatrixLayout = typename Profile::MatrixLayout; - using DataLayout = typename Profile::DataLayout; - - using Storer = OpaqueStore; - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // auto reg = 0u; - // auto x0 = fragsD.mAccess.data[0]; - // auto x1 = fragsD.mAccess.data[16]; - // auto x2 = fragsD.mAccess.data[32]; - // auto x3 = fragsD.mAccess.data[48]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - - // Pre-store to output fmt - GRBuffC tmp; - // tmp.mAccess.data[0] = fragsD.mAccess.data[0]; - // tmp.mAccess.data[1] = fragsD.mAccess.data[16]; - // tmp.mAccess.data[2] = fragsD.mAccess.data[32]; - // tmp.mAccess.data[3] = fragsD.mAccess.data[48]; - // tmp.mAccess.data[4] = fragsD.mAccess.data[4]; - // tmp.mAccess.data[5] = fragsD.mAccess.data[20]; - // tmp.mAccess.data[6] = fragsD.mAccess.data[36]; - // tmp.mAccess.data[7] = fragsD.mAccess.data[52]; - // tmp.mAccess.data[8] = fragsD.mAccess.data[8]; - // tmp.mAccess.data[9] = fragsD.mAccess.data[24]; - // tmp.mAccess.data[10] = fragsD.mAccess.data[40]; - // tmp.mAccess.data[11] = fragsD.mAccess.data[56]; - // tmp.mAccess.data[12] = fragsD.mAccess.data[12]; - // tmp.mAccess.data[13] = fragsD.mAccess.data[28]; - // tmp.mAccess.data[14] = fragsD.mAccess.data[44]; - // tmp.mAccess.data[15] = fragsD.mAccess.data[60]; - // tmp.mAccess.data[16] = fragsD.mAccess.data[1]; - // tmp.mAccess.data[17] = fragsD.mAccess.data[17]; - // tmp.mAccess.data[18] = fragsD.mAccess.data[33]; - // tmp.mAccess.data[19] = fragsD.mAccess.data[49]; - // tmp.mAccess.data[20] = fragsD.mAccess.data[5]; - // tmp.mAccess.data[21] = fragsD.mAccess.data[21]; - // tmp.mAccess.data[22] = fragsD.mAccess.data[37]; - // tmp.mAccess.data[23] = fragsD.mAccess.data[53]; - // tmp.mAccess.data[24] = fragsD.mAccess.data[9]; - // tmp.mAccess.data[25] = fragsD.mAccess.data[25]; - // tmp.mAccess.data[26] = fragsD.mAccess.data[41]; - // tmp.mAccess.data[27] = fragsD.mAccess.data[57]; - // tmp.mAccess.data[28] = fragsD.mAccess.data[13]; - // tmp.mAccess.data[29] = fragsD.mAccess.data[29]; - // tmp.mAccess.data[30] = fragsD.mAccess.data[45]; - // tmp.mAccess.data[31] = fragsD.mAccess.data[61]; - // tmp.mAccess.data[32] = fragsD.mAccess.data[2]; - // tmp.mAccess.data[33] = fragsD.mAccess.data[18]; - // tmp.mAccess.data[34] = fragsD.mAccess.data[34]; - // tmp.mAccess.data[35] = fragsD.mAccess.data[50]; - // tmp.mAccess.data[36] = fragsD.mAccess.data[6]; - // tmp.mAccess.data[37] = fragsD.mAccess.data[22]; - // tmp.mAccess.data[38] = fragsD.mAccess.data[38]; - // tmp.mAccess.data[39] = fragsD.mAccess.data[54]; - // tmp.mAccess.data[40] = fragsD.mAccess.data[10]; - // tmp.mAccess.data[41] = fragsD.mAccess.data[26]; - // tmp.mAccess.data[42] = fragsD.mAccess.data[42]; - // tmp.mAccess.data[43] = fragsD.mAccess.data[58]; - // tmp.mAccess.data[44] = fragsD.mAccess.data[14]; - // tmp.mAccess.data[45] = fragsD.mAccess.data[30]; - // tmp.mAccess.data[46] = fragsD.mAccess.data[46]; - // tmp.mAccess.data[47] = fragsD.mAccess.data[62]; - // tmp.mAccess.data[48] = fragsD.mAccess.data[3]; - // tmp.mAccess.data[49] = fragsD.mAccess.data[19]; - // tmp.mAccess.data[50] = fragsD.mAccess.data[35]; - // tmp.mAccess.data[51] = fragsD.mAccess.data[51]; - // tmp.mAccess.data[52] = fragsD.mAccess.data[7]; - // tmp.mAccess.data[53] = fragsD.mAccess.data[23]; - // tmp.mAccess.data[54] = fragsD.mAccess.data[39]; - // tmp.mAccess.data[55] = fragsD.mAccess.data[55]; - // tmp.mAccess.data[56] = fragsD.mAccess.data[11]; - // tmp.mAccess.data[57] = fragsD.mAccess.data[27]; - // tmp.mAccess.data[58] = fragsD.mAccess.data[43]; - // tmp.mAccess.data[59] = fragsD.mAccess.data[59]; - // tmp.mAccess.data[60] = fragsD.mAccess.data[15]; - // tmp.mAccess.data[61] = fragsD.mAccess.data[31]; - // tmp.mAccess.data[62] = fragsD.mAccess.data[47]; - // tmp.mAccess.data[63] = fragsD.mAccess.data[63]; + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Iterative offsets for each D block in the warp tile + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldd); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldd); + #pragma unroll - for(int i = 0; i < 4u; i++) + for(int i = 0; i < BLOCKS_X; i++) { - tmp.mAccess.data[i * 16 + 0 + 0] = fragsD.mAccess.data[0 * 16 + 0 + i]; - tmp.mAccess.data[i * 16 + 0 + 1] = fragsD.mAccess.data[1 * 16 + 0 + i]; - tmp.mAccess.data[i * 16 + 0 + 2] = fragsD.mAccess.data[2 * 16 + 0 + i]; - tmp.mAccess.data[i * 16 + 0 + 3] = fragsD.mAccess.data[3 * 16 + 0 + i]; - - tmp.mAccess.data[i * 16 + 4 + 0] = fragsD.mAccess.data[0 * 16 + 4 + i]; - tmp.mAccess.data[i * 16 + 4 + 1] = fragsD.mAccess.data[1 * 16 + 4 + i]; - tmp.mAccess.data[i * 16 + 4 + 2] = fragsD.mAccess.data[2 * 16 + 4 + i]; - tmp.mAccess.data[i * 16 + 4 + 3] = fragsD.mAccess.data[3 * 16 + 4 + i]; - - tmp.mAccess.data[i * 16 + 8 + 0] = fragsD.mAccess.data[0 * 16 + 8 + i]; - tmp.mAccess.data[i * 16 + 8 + 1] = fragsD.mAccess.data[1 * 16 + 8 + i]; - tmp.mAccess.data[i * 16 + 8 + 2] = fragsD.mAccess.data[2 * 16 + 8 + i]; - tmp.mAccess.data[i * 16 + 8 + 3] = fragsD.mAccess.data[3 * 16 + 8 + i]; - - tmp.mAccess.data[i * 16 + 12 + 0] = fragsD.mAccess.data[0 * 16 + 12 + i]; - tmp.mAccess.data[i * 16 + 12 + 1] = fragsD.mAccess.data[1 * 16 + 12 + i]; - tmp.mAccess.data[i * 16 + 12 + 2] = fragsD.mAccess.data[2 * 16 + 12 + i]; - tmp.mAccess.data[i * 16 + 12 + 3] = fragsD.mAccess.data[3 * 16 + 12 + i]; + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + store_matrix_sync(gAddrD + offsetY, fragsD[i][j], ldd); + offsetY += blockStepY; + } + gAddrD += blockStepX; } - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // if(threadIdx.x == 0) - // { - // printf("D Before STORE\n"); - // printf("Count: %d\n", tmp.num_elements); - // } - // auto reg = 0u; - // auto x0 = tmp.mAccess.data[0]; - // auto x1 = tmp.mAccess.data[16]; - // auto x2 = tmp.mAccess.data[32]; - // auto x3 = tmp.mAccess.data[48]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - - // Load then implicit pack - Storer::exec(gAddrD, tmp.mAccess, ldd); } -// Performs warp tile mfma -ROCWMMA_DEVICE static inline void mfma(AccumBuffInt& fragsAccOut, - LRBuffA const& fragsA, - LRBuffB const& fragsB, - AccumBuffInt const& fragsAccIn) +// Broadcast value to fragments in warp tile +template +ROCWMMA_DEVICE static inline void fill(FragT (&frags)[BLOCKS_X][BLOCKS_Y], + GetDataType_t value) { - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // auto x0 = fragsA.mAccess.data[0]; - // auto x1 = fragsA.mAccess.data[1]; - // auto x2 = fragsA.mAccess.data[2]; - // auto x3 = fragsA.mAccess.data[3]; - // printf("(A)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - - // x0 = fragsB.mAccess.data[0]; - // x1 = fragsB.mAccess.data[1]; - // x2 = fragsB.mAccess.data[2]; - // x3 = fragsB.mAccess.data[3]; - // printf("(B)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - // Need to get the MFMA tile size from the IO traits somehow - constexpr static uint32_t MFMFA_TILE = 16u; - - // From here, need to 'unpack' the interleaved data - // Should be 16 registers, need to re-order them in groups of 4 - LRBuffA tmpA; - LRBuffB tmpB; #pragma unroll - for(int i = 0; i < 4u; i++) + for(int i = 0; i < BLOCKS_X; i++) { - tmpA.mAccess.data[i * 4 + 0] = fragsA.mAccess.data[0 * 4 + i]; - tmpA.mAccess.data[i * 4 + 1] = fragsA.mAccess.data[1 * 4 + i]; - tmpA.mAccess.data[i * 4 + 2] = fragsA.mAccess.data[2 * 4 + i]; - tmpA.mAccess.data[i * 4 + 3] = fragsA.mAccess.data[3 * 4 + i]; - - tmpB.mAccess.data[i * 4 + 0] = fragsB.mAccess.data[0 * 4 + i]; - tmpB.mAccess.data[i * 4 + 1] = fragsB.mAccess.data[1 * 4 + i]; - tmpB.mAccess.data[i * 4 + 2] = fragsB.mAccess.data[2 * 4 + i]; - tmpB.mAccess.data[i * 4 + 3] = fragsB.mAccess.data[3 * 4 + i]; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + fill_fragment(frags[i][j], value); + } } +} - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // auto x0 = tmpA.mAccess.data[12]; - // auto x1 = tmpA.mAccess.data[13]; - // auto x2 = tmpA.mAccess.data[14]; - // auto x3 = tmpA.mAccess.data[15]; - // printf("(A)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - - // x0 = tmpB.mAccess.data[12]; - // x1 = tmpB.mAccess.data[13]; - // x2 = tmpB.mAccess.data[14]; - // x3 = tmpB.mAccess.data[15]; - // printf("(B)Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - - // Iterate over MFMA input requirements - // A = 16 regs unpacked, 8 packed - // B = 16 regs unpacked, 8 packed - // Accum = 64 regs unpacked/packed - // MFMA blocks = 16 x 4 regs - // Iterate through A - major - auto bIt = makeVectorIterator<2u>(tmpB.mStorage).begin(); - auto const accumInIt = makeVectorIterator<4u>(fragsAccOut.mStorage).begin(); - auto accumOutIt = makeVectorIterator<4u>(fragsAccOut.mStorage).begin(); - - using MMA = Mfma; - +// Performs warp tile mfma +ROCWMMA_DEVICE static inline void mfma(MfmaFragAcc (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], + MfmaFragA const (&fragsA)[BLOCKS_X], + MfmaFragB const (&fragsB)[BLOCKS_Y], + MfmaFragAcc const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) +{ #pragma unroll - for(int j = 0; j < 4u; j++) + for(int i = 0; i < BLOCKS_X; i++) { - auto aIt = makeVectorIterator<2u>(tmpA.mStorage).begin(); #pragma unroll - for(int i = 0; i < 4u; i++) + for(int j = 0; j < BLOCKS_Y; j++) { - // mma functions operate on packed vectors - *accumOutIt = MMA::exec(*aIt, *bIt, *accumInIt); - aIt++; - accumInIt++; - accumOutIt++; + mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); } - bIt++; } - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // if(threadIdx.x == 0) - // { - // printf("Count: %d\n", fragsAccOut.num_elements); - // } - // auto reg = 0u; - // auto x0 = fragsAccOut.mAccess.data[0]; - // auto x1 = fragsAccOut.mAccess.data[1]; - // auto x2 = fragsAccOut.mAccess.data[2]; - // auto x3 = fragsAccOut.mAccess.data[3]; - // printf("Thread %d: %#010x %#010x %#010x %#010x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } -} - -// Broadcast value to fragments in warp tile -template -ROCWMMA_DEVICE static inline void fill(FragT& frags, GetDataType_t value) -{ - fill_fragment(frags, value); } // Uniform multiply - add (FMA) // Performs D = alpha * acc + beta * C, where alpha, beta are uniform scalars -ROCWMMA_DEVICE static inline void uniformFma(GRBuffC& fragsD, - ComputeT alpha, - AccumBuffInt const& fragsAcc, - ComputeT beta, - GRBuffC const& fragsC) +ROCWMMA_DEVICE static inline void uniformFma(MfmaFragD (&fragsD)[BLOCKS_X][BLOCKS_Y], + ComputeT alpha, + MfmaFragAcc const (&fragsAcc)[BLOCKS_X][BLOCKS_Y], + ComputeT beta, + MfmaFragC const (&fragsC)[BLOCKS_X][BLOCKS_Y]) { - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // if(threadIdx.x == 0) - // { - // printf("Count: %d\n", fragsAcc.num_elements); - // } - // auto reg = 0u; - // auto x0 = fragsAcc.mAccess.data[0]; - // auto x1 = fragsAcc.mAccess.data[1]; - // auto x2 = fragsAcc.mAccess.data[2]; - // auto x3 = fragsAcc.mAccess.data[3]; - // printf("Thread %d: %#010x %#010x %#010x %#010x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // if(threadIdx.x == 0) - // { - // printf("Count: %d\n", fragsC.num_elements); - // } - // auto reg = 0u; - // auto x0 = fragsC.mAccess.data[0]; - // auto x1 = fragsC.mAccess.data[1]; - // auto x2 = fragsC.mAccess.data[2]; - // auto x3 = fragsC.mAccess.data[3]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } - - static constexpr uint32_t ChunkFactor = 2u; - static constexpr uint32_t ChunkSize = 64u / ChunkFactor; - auto dIt = makeVectorIterator(fragsD.mAccess).begin(); - auto const accumIt = makeVectorIterator(fragsAcc.mAccess).begin(); - auto const cIt = makeVectorIterator(fragsC.mAccess).begin(); - - for(int k = 0; k < fragsD.num_elements / ChunkFactor; k++) - { - // Perform computation in ComputeT and cast back to OutputT - (*dIt).data[k] = static_cast(alpha * (*accumIt).data[k] - + beta * static_cast((*cIt).data[k])); - } - - dIt++; - accumIt++; - cIt++; - - for(int k = 0; k < fragsD.num_elements / ChunkFactor; k++) +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) { - // Perform computation in ComputeT and cast back to OutputT - (*dIt).data[k] = static_cast(alpha * (*accumIt).data[k] - + beta * static_cast((*cIt).data[k])); +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int k = 0; k < fragsD[i][j].num_elements; k++) + { + // Perform computation in ComputeT and cast back to OutputT + fragsD[i][j].x[k] = static_cast( + alpha * fragsAcc[i][j].x[k] + beta * static_cast(fragsC[i][j].x[k])); + } + } } - - // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x <= 63 && threadIdx.y == 0) - // { - // if(threadIdx.x == 0) - // { - // printf("D AFTER UNIFORM FMA\n"); - // printf("Count: %d\n", fragsD.num_elements); - // } - // auto reg = 0u; - // auto x0 = fragsD.mAccess.data[0]; - // auto x1 = fragsD.mAccess.data[16]; - // auto x2 = fragsD.mAccess.data[32]; - // auto x3 = fragsD.mAccess.data[48]; - // printf("Thread %d: %#06x %#06x %#06x %#06x\n", threadIdx.x, reinterpret_cast(x0), reinterpret_cast(x1), reinterpret_cast(x2), reinterpret_cast(x3)); - // } } -//ROCWMMA_KERNEL void gemm_rocwmma_d(uint32_t m, -//ROCWMMA_KERNEL void __attribute__((amdgpu_num_vgpr(0))) gemm_rocwmma_d(uint32_t m, -ROCWMMA_KERNEL void __launch_bounds__(1024) gemm_rocwmma_d( - uint32_t m, - //ROCWMMA_KERNEL void __attribute__((amdgpu_waves_per_eu(1))) gemm_rocwmma_d(uint32_t m, - uint32_t n, - uint32_t k, - InputT const* a, - InputT const* b, - OutputT const* c, - OutputT* d, - uint32_t lda, - uint32_t ldb, - uint32_t ldc, - uint32_t ldd, - ComputeT alpha, - ComputeT beta) +ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, + uint32_t n, + uint32_t k, + InputT const* a, + InputT const* b, + OutputT const* c, + OutputT* d, + uint32_t lda, + uint32_t ldb, + uint32_t ldc, + uint32_t ldd, + ComputeT alpha, + ComputeT beta) { if constexpr(!ROCWMMA_ARCH_HOST) { @@ -938,7 +637,7 @@ ROCWMMA_KERNEL void __launch_bounds__(1024) gemm_rocwmma_d( /// /// Initialize accumulation frags /// - AccumBuffInt fragsAcc; + MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; fill(fragsAcc, 0.0f); /// @@ -949,27 +648,19 @@ ROCWMMA_KERNEL void __launch_bounds__(1024) gemm_rocwmma_d( /// /// Accumulate A * B for all mfma frags in warp tile /// - // - LDS Triple buffer - // - LDS no buffer-> tiny m/n large K - // - unroll K to have more work - // - __restrict__ - // for(uint32_t currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) { - // Make sure that all waves have finished reading / writing to lds for currentK. - synchronize_workgroup(); - - // Prefetch next round of global frags - globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); - globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); - - LRBuffA fragsA; - LRBuffB fragsB; + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; // Local read mfma frags from first LDS buffer localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + // Prefetch next round of global frags + globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); + // Advance offsets to next k step globalReadOffsetA += kStepOffsetA; globalReadOffsetB += kStepOffsetB; @@ -981,93 +672,41 @@ ROCWMMA_KERNEL void __launch_bounds__(1024) gemm_rocwmma_d( localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + // Make sure that all waves have finished reading / writing to lds for currentK. + synchronize_workgroup(); + // Swap Lds buffers auto* tmp = ldsPtrLo; ldsPtrLo = ldsPtrHi; ldsPtrHi = tmp; - - // Scheduling - - // // VMEM read - // __builtin_amdgcn_sched_group_barrier(32, 2, 0); - // // DS read - // __builtin_amdgcn_sched_group_barrier(256, 16, 0); - // // Non-VMEM - // __builtin_amdgcn_sched_group_barrier(1, 16, 0); - // // MFMA - // __builtin_amdgcn_sched_group_barrier(8, 4, 0); - // // DS read - // __builtin_amdgcn_sched_group_barrier(256, 16, 1); - // // // Non-VMEM - // __builtin_amdgcn_sched_group_barrier(1, 16, 1); - // // MFMA - // __builtin_amdgcn_sched_group_barrier(8, 4, 1); - // // DS write - // __builtin_amdgcn_sched_group_barrier(512, 32, 0); - - ////////// Works good - 127.46 - // VMEM read - __builtin_amdgcn_sched_group_barrier(32, 4, 0); - // DS read - __builtin_amdgcn_sched_group_barrier(256, 64, 0); - // SALU - __builtin_amdgcn_sched_group_barrier(4, 256, 0); - // VALU - __builtin_amdgcn_sched_group_barrier(2, 256, 0); - // MFMA - __builtin_amdgcn_sched_group_barrier(8, 16, 0); - // DS write - __builtin_amdgcn_sched_group_barrier(512, 64, 0); - ////////////////// } - // Make sure that all waves have finished reading / writing to lds for currentK. - synchronize_workgroup(); - /// /// Start loading C /// using MfmaFragCMap1d = GetDataLayout_t; using MfmaFragDMap1d = GetDataLayout_t; - GRBuffC fragsC; + MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - // /// - // /// Clean up tail A * B - // /// - LRBuffA fragsA; - LRBuffB fragsB; + /// + /// Clean up tail A * B + /// + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; - // // Local read mfma frags + // Local read mfma frags localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); mfma(fragsAcc, fragsA, fragsB, fragsAcc); - // /// - // /// D = alpha * accum + beta * C - // /// - GRBuffC fragsD; + /// + /// D = alpha * accum + beta * C + /// + MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); - //globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), reinterpret_cast(fragsAcc), ldd); globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); - - ////////// Works good - 127.46 - // DS read - __builtin_amdgcn_sched_group_barrier(256, 64, 0); - // VMEM read - __builtin_amdgcn_sched_group_barrier(32, 64, 0); - - // MFMA - __builtin_amdgcn_sched_group_barrier(8, 16, 0); - // SALU - __builtin_amdgcn_sched_group_barrier(4, 256, 0); - // VALU - __builtin_amdgcn_sched_group_barrier(2, 512, 0); - - // VMEM write - __builtin_amdgcn_sched_group_barrier(512, 64, 0); - ////////////////// } } @@ -1141,12 +780,6 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, fillRand(matrixA.data(), m, k); fillRand(matrixB.data(), k, n); fillRand(matrixC.data(), m, n); - //fillEnc(matrixA.data(), m, k); - //printEnc(matrixA.data(), m, k); - //fillEnc(matrixB.data(), k, n); - //printEnc(matrixB.data(), k, n); - //fillEnc(matrixC.data(), m, n); - //printEnc(matrixC.data(), m, n); std::cout << "Initializing device data..." << std::endl; @@ -1208,7 +841,7 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, beta); }; - constexpr uint32_t warmups = 50u; + constexpr uint32_t warmups = 2u; constexpr uint32_t recordRuns = 5u; // Warm-up runs, not recorded @@ -1255,7 +888,7 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, << ldc << ", " << ldd << ", " << elapsedTimeMs << ", " << gFlops << ", " << tFlopsPerSec << std::endl; -#if 1 +#if !NDEBUG std::cout << "Validating result with reference..." << std::endl; @@ -1285,11 +918,6 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, auto res = compareEqual(matrixD.data(), matrixD_ref.data(), m * n); - //std::cout << "Reference: \n"; - //printData(matrixD_ref.data(), m, n); - //std::cout << "Actual:\n"; - //printData(matrixD.data(), m, n); - if(std::get<0>(res) == false) { std::cout << "FAILED\n"; @@ -1315,7 +943,5 @@ ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, int main() { gemm_test(7168, 7168, 7168, 2, 2); - //gemm_test(8192, 8192, 8192, 2, 2); - //gemm_test(128, 128, 16, 2, 2); return 0; -} +} \ No newline at end of file From 43c96fe7a775c0707451733db245f5d921bd4252 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 2 Dec 2024 17:34:32 -0700 Subject: [PATCH 25/36] Skip tests on invalid layout condition for BlockK --- test/unit/layout_traits_test/device/layout_traits_int.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp index b4323457..8f0e2e39 100644 --- a/test/unit/layout_traits_test/device/layout_traits_int.hpp +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -1333,7 +1333,7 @@ namespace rocwmma constexpr uint32_t kpt = kPerThread(); // Must satisfy layout requirement - if constexpr(dpt > 0u && kpt > 0u) + if constexpr(dpt > 0u && kpt > 0u && BlockK >= kpt) { // clang-format off result &= matrixLayoutTraitsTestInterleaved0(); From 1aeb382389d3611317e0b4e435ec69dfaf11dde8 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Mon, 2 Dec 2024 17:35:50 -0700 Subject: [PATCH 26/36] Add a softer warning for unsupported transform attempts --- .../internal/layout/register_layout_transforms.hpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 62233661..2de7ae6b 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -70,7 +70,13 @@ namespace rocwmma && (!traits_lhs::is_register_layout || !traits_rhs::is_register_layout || !is_layout_orthogonal_v)>> { - static_assert(0, "Register layout transform is not supported"); + template + ROCWMMA_UNSUPPORTED_IMPL("Register layout transform is not supported") + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // No-op + return v; + } }; // Apply paths between orthogonal transforms From abb085f39403386509f33c8baa3b7976f63a7ad2 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 4 Dec 2024 09:38:34 -0700 Subject: [PATCH 27/36] Adjust MaxVWSelector to fit more layout constraints --- .../include/rocwmma/internal/io_layout.hpp | 22 ++++++++++++++++--- .../internal/layout/matrix_layout_impl.hpp | 4 ++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index fc713c7a..79813ad3 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -63,11 +63,27 @@ namespace rocwmma static constexpr bool ElementCountTest = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0); - // Currently, all layouts are using ColOrthoVW. This means that VW must be less than BlockK - static constexpr bool LeadingDimTest = (TestWidth <= BlockK); + // Check the layout geometry. Avoids triggering static asserts for invalid layout. + // matrix_a (BlockDim <= 32): col_major, row_major -> ColOrtho req: BlockKStride <= BlockK + // (BlockDim > 32): col_major -> ColInline req: MaxVW <= BlockDim + // row_major -> ColOrtho req: BlockKStride <= BlockK + // matrix_b (BlockDim <= 32): col_major, row_major -> RowOrtho req: BlockKStride <= BlockK + // (BlockDim > 32): row_major -> ColInline req: MaxVW <= BlockDim + // col_major -> ColOrtho req: BlockKStride <= BlockK + // + // Note: BlockKStride is non-interleaved layout specific, and determines whether the gathered + // data at a specific MaxVW fits within BlockK dimension. + static constexpr bool BlockDimTest = TestWidth <= BlockDim; + static constexpr bool BlockKTest = (Constants::AMDGCN_WAVE_SIZE * TestWidth / min(BlockDim, Constants::AMDGCN_WAVE_SIZE)) <= BlockK; + + // TODO: These could really be more layout specific. + // This is limiting for small BlockDim and large K. + static constexpr bool MatrixATest = is_same_v ? (BlockDimTest && BlockKTest) : BlockKTest; + static constexpr bool MatrixBTest = is_same_v ? (BlockDimTest && BlockKTest) : BlockKTest; + static constexpr bool VWFitnessTest = (is_same_v && MatrixATest) || (is_same_v && MatrixBTest); // Decide on final MaxVW - static constexpr uint32_t MaxVectorWidth = (ElementCountTest && LeadingDimTest) + static constexpr uint32_t MaxVectorWidth = (ElementCountTest && VWFitnessTest) ? TestWidth : MaxVWSelector= BlockDimStride_X, "BlockDim must be larger than BlockDimStride_X"); static_assert(BlockDim % BlockDimStride_X == 0, @@ -388,6 +390,8 @@ namespace rocwmma = DimPerThread * KPerThread * BlockDimSegs; // Sanity checks for strides sizes + static_assert(MaxVectorWidth <= BlockDim, + "MaxVectorWidth cannot exceed BlockDim"); static_assert(BlockDim >= BlockDimStride_X, "BlockDim must be larger than BlockDimStride_X"); static_assert(BlockDim % BlockDimStride_X == 0, From a420f9bb7eab369a70bf451087cb7c19566d9c0c Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 4 Dec 2024 09:40:51 -0700 Subject: [PATCH 28/36] Update / correct non-interleaved layout tests --- test/unit/layout_test/device/col_layout.hpp | 31 +++++++---------- test/unit/layout_test/device/colnt_layout.hpp | 31 +++++++---------- test/unit/layout_test/device/row_layout.hpp | 32 ++++++++---------- test/unit/layout_test/device/rownt_layout.hpp | 33 +++++++------------ 4 files changed, 49 insertions(+), 78 deletions(-) diff --git a/test/unit/layout_test/device/col_layout.hpp b/test/unit/layout_test/device/col_layout.hpp index a5d8bc9a..b189983c 100644 --- a/test/unit/layout_test/device/col_layout.hpp +++ b/test/unit/layout_test/device/col_layout.hpp @@ -56,41 +56,34 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockM, - KDim = BlockN, + BlockK = BlockN, - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = MaxVectorWidth }; - using IOTraits = IOTraits; + using IOTraits = IOTraits; using LayoutT = conditional_t< is_same_v, - MatrixLayout::ColInlineVW, - MatrixLayout::ColOrthoVW>; + MatrixLayout::ColInlineVW, + MatrixLayout::ColOrthoVW>; using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } diff --git a/test/unit/layout_test/device/colnt_layout.hpp b/test/unit/layout_test/device/colnt_layout.hpp index 84b40b19..47ccb0ef 100644 --- a/test/unit/layout_test/device/colnt_layout.hpp +++ b/test/unit/layout_test/device/colnt_layout.hpp @@ -56,39 +56,32 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockM, - KDim = BlockN, + BlockK = BlockN, - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = std::is_same_v ? MaxVectorWidth : 1u, }; - using IOTraits = IOTraits; + using IOTraits = IOTraits; using LayoutT = conditional_t< is_same_v, - MatrixLayout::ColOrthoVW, - MatrixLayout::ColOrthoVW>; + MatrixLayout::ColOrthoVW, + MatrixLayout::ColOrthoVW>; using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } diff --git a/test/unit/layout_test/device/row_layout.hpp b/test/unit/layout_test/device/row_layout.hpp index ea9a4898..da201fbf 100644 --- a/test/unit/layout_test/device/row_layout.hpp +++ b/test/unit/layout_test/device/row_layout.hpp @@ -57,39 +57,33 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockN, - KDim = BlockM, + BlockK = BlockM, - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = MaxVectorWidth }; - using IOTraits = IOTraits; + using IOTraits = IOTraits; using LayoutT = conditional_t< is_same_v, - MatrixLayout::RowInlineVW, - MatrixLayout::RowOrthoVW>; + MatrixLayout::RowInlineVW, + MatrixLayout::RowOrthoVW>; + using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } diff --git a/test/unit/layout_test/device/rownt_layout.hpp b/test/unit/layout_test/device/rownt_layout.hpp index c19c9b99..fc8967b8 100644 --- a/test/unit/layout_test/device/rownt_layout.hpp +++ b/test/unit/layout_test/device/rownt_layout.hpp @@ -57,41 +57,32 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockN, - KDim = BlockM, + BlockK = BlockM, - MaxVectorWidth - = std::is_same_v - ? 1 - : detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1, + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = std::is_same_v ? MaxVectorWidth : 1u, }; - using IOTraits = IOTraits; + using IOTraits = IOTraits; using LayoutT = conditional_t< is_same_v, - MatrixLayout::RowOrthoVW, - MatrixLayout::RowOrthoVW>; + MatrixLayout::RowOrthoVW, + MatrixLayout::RowOrthoVW>; using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } From e494ec14ea6a528bdae4c89afa4299ff61ae1f24 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 4 Dec 2024 09:41:35 -0700 Subject: [PATCH 29/36] Prevent sgemm kernel from building on unsupported targets --- samples/perf_sgemm.cpp | 327 +++++++++++++++++++++-------------------- 1 file changed, 165 insertions(+), 162 deletions(-) diff --git a/samples/perf_sgemm.cpp b/samples/perf_sgemm.cpp index 0ca4f2fc..4224b87a 100644 --- a/samples/perf_sgemm.cpp +++ b/samples/perf_sgemm.cpp @@ -472,185 +472,188 @@ ROCWMMA_KERNEL void __launch_bounds__(256) gemm_rocwmma_d(uint32_t m, ComputeT alpha, ComputeT beta) { - /// - /// 2D matrix coordinate setup - /// + if constexpr((bool)ROCWMMA_ARCH_GFX9) + { + /// + /// 2D matrix coordinate setup + /// + + // Tile Sizes + constexpr auto warpTileSize = make_coord2d(WARP_TILE_X, WARP_TILE_Y); + constexpr auto macroTileSize = make_coord2d(MACRO_TILE_X, MACRO_TILE_Y); + + // Local warp coordinate relative to current threadblock (wg). + constexpr auto warpDims = make_coord2d(WARPS_X, WARPS_Y); + auto localWarpCoord = make_coord2d(threadIdx.x / WARP_SIZE, threadIdx.y); + auto localWarpOffset = localWarpCoord * warpTileSize; + + // Global matrix coordinates for C/D + auto macroTileCoord = make_coord2d(blockIdx.x, blockIdx.y) * macroTileSize; + auto warpTileCoord = macroTileCoord + localWarpOffset; + + // Bounds check + auto warpTileBound = warpTileCoord + warpTileSize; + if(get<0>(warpTileBound) > m || get<1>(warpTileBound) > n) + { + return; + } - // Tile Sizes - constexpr auto warpTileSize = make_coord2d(WARP_TILE_X, WARP_TILE_Y); - constexpr auto macroTileSize = make_coord2d(MACRO_TILE_X, MACRO_TILE_Y); + /// + /// 1D global read coordinate setup + /// + using GRBuffAMap1d = GetDataLayout_t; + using GRBuffBMap1d = GetDataLayout_t; - // Local warp coordinate relative to current threadblock (wg). - constexpr auto warpDims = make_coord2d(WARPS_X, WARPS_Y); - auto localWarpCoord = make_coord2d(threadIdx.x / WARP_SIZE, threadIdx.y); - auto localWarpOffset = localWarpCoord * warpTileSize; + // Initial globa read address offsets + auto globalReadOffsetA + = GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); + auto globalReadOffsetB + = GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); - // Global matrix coordinates for C/D - auto macroTileCoord = make_coord2d(blockIdx.x, blockIdx.y) * macroTileSize; - auto warpTileCoord = macroTileCoord + localWarpOffset; + // Incremental global read address offsets + auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); + auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); - // Bounds check - auto warpTileBound = warpTileCoord + warpTileSize; - if(get<0>(warpTileBound) > m || get<1>(warpTileBound) > n) - { - return; - } + /// + /// Cooperative config for global read A / B + /// - /// - /// 1D global read coordinate setup - /// - using GRBuffAMap1d = GetDataLayout_t; - using GRBuffBMap1d = GetDataLayout_t; - - // Initial globa read address offsets - auto globalReadOffsetA - = GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); - auto globalReadOffsetB - = GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); - - // Incremental global read address offsets - auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); - auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); - - /// - /// Cooperative config for global read A / B - /// - - // WorkItems will be split up by minimum IOCount to perform either global read or local write. - // These are inputs to cooperative functions. - constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); - - // Scheduling warp order is analogous to row major priority. - // E.g. Wg = (128, 2) = 2x2 warps - // (0, 0) (0, 1) Share Schedule: w0 = (0, 0), w1 = (0, 1), - // (1, 0) (1, 1) w2 = (1, 0), w3 = (1, 1), count = 4 - const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); - - /// - /// Perform initial global pre-fetch - /// - - GRBuffA grBuffA; - GRBuffB grBuffB; - - globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); - globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); - - globalReadOffsetA += kStepOffsetA; - globalReadOffsetB += kStepOffsetB; - - /// - /// Setup LDS addressing - /// This kernel will use 2 separate LDS blocks for pipelining - /// the input prefetching during the accumulation loop - /// - - HIP_DYNAMIC_SHARED(void*, localMemPtr); - using LWBuffAShape = GetIOShape_t; - using LWBuffBShape = GetIOShape_t; - using LWBuffAMap1d = GetDataLayout_t; - using LWBuffBMap1d = GetDataLayout_t; - - constexpr uint32_t ldsWidth = ROCWMMA_K; - constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; - constexpr uint32_t sizeLds = ldsHeight * ldsWidth; - constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; - - auto* ldsPtrLo = reinterpret_cast(localMemPtr); - auto* ldsPtrHi = ldsPtrLo + sizeLds; - - // Local write offsets to start of A / B data - auto ldsWriteOffsetA = 0u; - auto ldsWriteOffsetB - = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); - - // Local read offsets for mfma frags - auto ldsReadOffsetA - = ldsWriteOffsetA - + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); - auto ldsReadOffsetB - = ldsWriteOffsetB - + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); - - /// - /// Write prefetch to local - /// - localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); - localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); - - /// - /// Initialize accumulation frags - /// - MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; - fill(fragsAcc, 0.0f); - - /// - /// Synchronize warps and memory - /// - synchronize_workgroup(); - - /// - /// Accumulate A * B for all mfma frags in warp tile - /// - for(auto currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) - { - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; + // WorkItems will be split up by minimum IOCount to perform either global read or local write. + // These are inputs to cooperative functions. + constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); - // Local read mfma frags from first LDS buffer - localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); - localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + // Scheduling warp order is analogous to row major priority. + // E.g. Wg = (128, 2) = 2x2 warps + // (0, 0) (0, 1) Share Schedule: w0 = (0, 0), w1 = (0, 1), + // (1, 0) (1, 1) w2 = (1, 0), w3 = (1, 1), count = 4 + const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); + + /// + /// Perform initial global pre-fetch + /// + + GRBuffA grBuffA; + GRBuffB grBuffB; - // Prefetch next round of global frags globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); - // Advance offsets to next k step globalReadOffsetA += kStepOffsetA; globalReadOffsetB += kStepOffsetB; - // accum(A * B) - mfma(fragsAcc, fragsA, fragsB, fragsAcc); + /// + /// Setup LDS addressing + /// This kernel will use 2 separate LDS blocks for pipelining + /// the input prefetching during the accumulation loop + /// + + HIP_DYNAMIC_SHARED(void*, localMemPtr); + using LWBuffAShape = GetIOShape_t; + using LWBuffBShape = GetIOShape_t; + using LWBuffAMap1d = GetDataLayout_t; + using LWBuffBMap1d = GetDataLayout_t; + + constexpr uint32_t ldsWidth = ROCWMMA_K; + constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; + constexpr uint32_t sizeLds = ldsHeight * ldsWidth; + constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; + + auto* ldsPtrLo = reinterpret_cast(localMemPtr); + auto* ldsPtrHi = ldsPtrLo + sizeLds; + + // Local write offsets to start of A / B data + auto ldsWriteOffsetA = 0u; + auto ldsWriteOffsetB + = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); + + // Local read offsets for mfma frags + auto ldsReadOffsetA + = ldsWriteOffsetA + + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); + auto ldsReadOffsetB + = ldsWriteOffsetB + + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); + + /// + /// Write prefetch to local + /// + localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + /// + /// Initialize accumulation frags + /// + MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; + fill(fragsAcc, 0.0f); + + /// + /// Synchronize warps and memory + /// + synchronize_workgroup(); - // Write prefetch to second LDS buffer - localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); - localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + /// + /// Accumulate A * B for all mfma frags in warp tile + /// + for(auto currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) + { + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; - // Make sure that all waves have finished reading / writing to lds for currentK. - synchronize_workgroup(); + // Local read mfma frags from first LDS buffer + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); - // Swap Lds buffers - auto* tmp = ldsPtrLo; - ldsPtrLo = ldsPtrHi; - ldsPtrHi = tmp; - } + // Prefetch next round of global frags + globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); + + // Advance offsets to next k step + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + // accum(A * B) + mfma(fragsAcc, fragsA, fragsB, fragsAcc); + + // Write prefetch to second LDS buffer + localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + // Make sure that all waves have finished reading / writing to lds for currentK. + synchronize_workgroup(); + + // Swap Lds buffers + auto* tmp = ldsPtrLo; + ldsPtrLo = ldsPtrHi; + ldsPtrHi = tmp; + } + + /// + /// Start loading C + /// + using MfmaFragCMap1d = GetDataLayout_t; + using MfmaFragDMap1d = GetDataLayout_t; + + MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; + globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - /// - /// Start loading C - /// - using MfmaFragCMap1d = GetDataLayout_t; - using MfmaFragDMap1d = GetDataLayout_t; - - MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; - globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - - /// - /// Clean up tail A * B - /// - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; - - // Local read mfma frags - localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); - localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); - mfma(fragsAcc, fragsA, fragsB, fragsAcc); - - /// - /// D = alpha * accum + beta * C - /// - MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; - uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); - globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); + /// + /// Clean up tail A * B + /// + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + // Local read mfma frags + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + mfma(fragsAcc, fragsA, fragsB, fragsAcc); + + /// + /// D = alpha * accum + beta * C + /// + MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; + uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); + globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); + } } ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, ComputeT beta) From 8c75e884f34b8006f3b88bc43d62ee5e16c0529b Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 5 Dec 2024 10:58:48 -0700 Subject: [PATCH 30/36] Fixes: remove default Format argument to avoid usage mistakes; fix test cases for gfx11 MmaInput and MmaAcc; fix MaxVW selector to fit proper layout fitness --- .../include/rocwmma/internal/io_layout.hpp | 26 +- .../rocwmma/internal/layout/layout.hpp | 4 +- .../layout/register_layout_traits_impl.hpp | 26 +- .../device/layout_traits.hpp | 561 +++++++++++++----- .../device/layout_traits_int.hpp | 8 +- 5 files changed, 436 insertions(+), 189 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 79813ad3..df09d26e 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -63,27 +63,17 @@ namespace rocwmma static constexpr bool ElementCountTest = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0); - // Check the layout geometry. Avoids triggering static asserts for invalid layout. - // matrix_a (BlockDim <= 32): col_major, row_major -> ColOrtho req: BlockKStride <= BlockK - // (BlockDim > 32): col_major -> ColInline req: MaxVW <= BlockDim - // row_major -> ColOrtho req: BlockKStride <= BlockK - // matrix_b (BlockDim <= 32): col_major, row_major -> RowOrtho req: BlockKStride <= BlockK - // (BlockDim > 32): row_major -> ColInline req: MaxVW <= BlockDim - // col_major -> ColOrtho req: BlockKStride <= BlockK - // - // Note: BlockKStride is non-interleaved layout specific, and determines whether the gathered - // data at a specific MaxVW fits within BlockK dimension. - static constexpr bool BlockDimTest = TestWidth <= BlockDim; + // Layout fitness check: + // Basic non-interleaved layouts are classified into *OrthoVW (SOA) and *InlineVW (AOS) formats. + // For any BlockDim/BlockK geometry, we ensure that these layouts come up with the same MaxVW, + // so that the AOS <-> SOA transforms are possible and valid. The followings tests assure this. static constexpr bool BlockKTest = (Constants::AMDGCN_WAVE_SIZE * TestWidth / min(BlockDim, Constants::AMDGCN_WAVE_SIZE)) <= BlockK; - - // TODO: These could really be more layout specific. - // This is limiting for small BlockDim and large K. - static constexpr bool MatrixATest = is_same_v ? (BlockDimTest && BlockKTest) : BlockKTest; - static constexpr bool MatrixBTest = is_same_v ? (BlockDimTest && BlockKTest) : BlockKTest; - static constexpr bool VWFitnessTest = (is_same_v && MatrixATest) || (is_same_v && MatrixBTest); + static constexpr bool OrthoTest = TestWidth <= BlockK; + static constexpr bool InlineTest = TestWidth <= BlockDim; + static constexpr bool LayoutFitnessTest = (BlockKTest && OrthoTest && InlineTest); // Decide on final MaxVW - static constexpr uint32_t MaxVectorWidth = (ElementCountTest && VWFitnessTest) + static constexpr uint32_t MaxVectorWidth = (ElementCountTest && LayoutFitnessTest) ? TestWidth : MaxVWSelector + Format Fmt> struct MmaInput { }; @@ -179,7 +179,7 @@ namespace rocwmma template + Format Fmt> struct MmaAcc { }; diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 019338ff..46181ff9 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -517,8 +517,6 @@ namespace rocwmma constexpr bool TestOpposingFormat = ( (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) - || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) - || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::AOS_INT) || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::SOA_INT) || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) @@ -528,7 +526,29 @@ namespace rocwmma || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::SOA_INT) || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::AOS_INT) || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR)) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + // gfx11 transforms + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11 && traits_rhs::Format == Format::AOS_INT) + ) && (traits_lhs::is_valid && traits_rhs::is_valid); // clang-format on diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp index 35bdba98..e469ffdb 100644 --- a/test/unit/layout_traits_test/device/layout_traits.hpp +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -45,9 +45,8 @@ namespace rocwmma ROCWMMA_HOST bool testLayoutPair(const char* file, const char* line, std::ostream& stream = std::cout) { - constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; - constexpr bool is_layout_orthogonal_result - = rocwmma::is_layout_orthogonal_v; + constexpr bool is_layout_same_result = is_layout_same_v; + constexpr bool is_layout_orthogonal_result = is_layout_orthogonal_v; constexpr bool compare_result = ((is_layout_same_result == ExpectSame) && (is_layout_orthogonal_result == ExpectOrthogonal)); @@ -83,9 +82,8 @@ namespace rocwmma bool DebugOnFail> ROCWMMA_DEVICE bool testLayoutPair(const char* file, uint32_t line) { - constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; - constexpr bool is_layout_orthogonal_result - = rocwmma::is_layout_orthogonal_v; + constexpr bool is_layout_same_result = is_layout_same_v; + constexpr bool is_layout_orthogonal_result = is_layout_orthogonal_v; constexpr bool compare_result = ((is_layout_same_result == ExpectSame) && (is_layout_orthogonal_result == ExpectOrthogonal)); @@ -157,8 +155,12 @@ namespace rocwmma MatrixLayout::RowOrthoVW, DataLayout>; - using MmaInput = RegisterLayout::MmaInput; - using MmaAcc = RegisterLayout::MmaAcc; + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; }; template @@ -240,6 +242,21 @@ namespace rocwmma : ((is_same_v || (bool)ROCWMMA_ARCH_GFX11) ? 1u : 4u)); } + template + ROCWMMA_DEVICE constexpr void debug() + { + if(isFirstThread()) + { + using traits_lhs = layout_traits; + using traits_rhs = layout_traits; + printf("testCompatibleRegisterParams: %d\n", LayoutTraits_impl::testCompatibleRegisterParams()); + printf("MmaDim: %d, MmaDim: %d\n", traits_lhs::MmaDim, traits_rhs::MmaDim); + printf("DataFormat: %d, DataFormat: %d\n", (int)traits_lhs::Format, (int)traits_rhs::Format); + printf("is_valid: %d, is_valid: %d\n", traits_lhs::is_valid, traits_rhs::is_valid); + printf("is_same_dataT: %d\n", is_same_v); + } + } + template mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -372,7 +416,7 @@ namespace rocwmma bool result = true; - // Case is tested in #3 + // Covered in another test case if constexpr(VectorWidth == 1u) { return result; @@ -400,25 +444,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -480,25 +551,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -560,25 +658,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -655,25 +780,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -750,25 +902,53 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -835,25 +1015,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -920,25 +1127,52 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); @@ -1251,8 +1485,7 @@ namespace rocwmma constexpr uint32_t WaveCount = 1u; constexpr uint32_t BlockDim = BlockM; constexpr uint32_t BlockK = BlockN; - constexpr uint32_t MaxVW = rocwmma::detail:: - MaxVWSelector::Result; + constexpr uint32_t MaxVW = detail::MaxVWSelector::Result; bool result = true; result &= dataLayoutTraitsTest(); diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp index 8f0e2e39..3d68b903 100644 --- a/test/unit/layout_traits_test/device/layout_traits_int.hpp +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -129,8 +129,12 @@ namespace rocwmma MatrixLayout::RowOrthoInt, DataLayoutT>; - using MmaInput = RegisterLayout::MmaInput; - using MmaAcc = RegisterLayout::MmaAcc; + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; }; template From 73fb53cefae31444cee1ce3c024298cdb0138df0 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Thu, 5 Dec 2024 15:15:49 -0700 Subject: [PATCH 31/36] Fixup interleaved tests on gfx11 --- .../layout/register_layout_traits_impl.hpp | 23 +- .../device/layout_traits_int.hpp | 275 +++++++++++++----- 2 files changed, 222 insertions(+), 76 deletions(-) diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 46181ff9..67e7c6bf 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -475,20 +475,29 @@ namespace rocwmma constexpr bool TestFormatMatch = (traits_lhs::Format == traits_rhs::Format); if constexpr((traits_lhs::is_interleaved && traits_rhs::is_interleaved) - && ((traits_lhs::is_storage && traits_rhs::is_storage) + && ((traits_lhs::is_storage && traits_rhs::is_storage) || (traits_lhs::is_storage && traits_rhs::is_mma_input) || (traits_lhs::is_mma_input && traits_rhs::is_storage))) { using storage_traits = conditional_t; - // Special case: interleaved layouts - // Check matching thread dims and if either one is == 1u. - // Register contents will be identical, regardless of different formats. - constexpr bool TestIdentityQuirks - = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); + // Gfx11 MmaInput requires some additional transforms + if constexpr((bool)ROCWMMA_ARCH_GFX11 + && (traits_lhs::is_mma_input || traits_rhs::is_mma_input)) + { + return TestCompatibleParams && TestFormatMatch; + } + else + { + // Special case: interleaved layouts + // Check matching thread dims and if either one is == 1u. + // Register contents will be identical, regardless of different formats. + constexpr bool TestIdentityQuirks + = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); - return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); + return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); + } } else { diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp index 3d68b903..ce47a0ec 100644 --- a/test/unit/layout_traits_test/device/layout_traits_int.hpp +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -290,15 +290,32 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); @@ -383,15 +400,32 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); @@ -486,15 +520,32 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_dpt_eq_1, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); @@ -587,15 +638,32 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); @@ -694,15 +762,32 @@ namespace rocwmma // Storage <-> mma layouts // Same MmaDim - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); @@ -805,15 +890,32 @@ namespace rocwmma // Storage <-> mma layouts // Same MmaDim - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); @@ -1186,15 +1288,32 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); - - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + } result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); @@ -1295,15 +1414,33 @@ namespace rocwmma result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); // Storage <-> mma layouts - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); - result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); From 4d6fab8d0f9e348e857fb059b555b69f9e56fd26 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 10 Dec 2024 20:17:49 -0700 Subject: [PATCH 32/36] Allow acc post mma xform to convert gfx11 mma acc quirk into configuration --- .../include/rocwmma/internal/io_config.hpp | 10 ++- .../include/rocwmma/internal/io_layout.hpp | 64 +++++++++---------- .../layout/register_layout_traits_impl.hpp | 10 +-- .../layout/register_layout_transforms.hpp | 43 ++++++++++++- library/include/rocwmma/internal/wmma.hpp | 25 +++----- library/include/rocwmma/rocwmma_impl.hpp | 16 +++-- 6 files changed, 107 insertions(+), 61 deletions(-) diff --git a/library/include/rocwmma/internal/io_config.hpp b/library/include/rocwmma/internal/io_config.hpp index 609f27a6..e9c3db57 100644 --- a/library/include/rocwmma/internal/io_config.hpp +++ b/library/include/rocwmma/internal/io_config.hpp @@ -38,7 +38,6 @@ namespace rocwmma { - /** * \defgroup Rocwmma_ioconf ROCWMMA IOConfig * @brief ROCWMMA fragment input and output configurations @@ -95,6 +94,12 @@ namespace rocwmma using PreMmaXForm = register_layout_transform; + // Currently, only makes sense to have a post-mma transform on acc layouts + using PostMmaXForm = conditional_t, + register_layout_transform, + register_layout_transform_nop>; + using PreStoreXForm = register_layout_transform; @@ -124,6 +129,9 @@ namespace rocwmma using PreMmaXForm = register_layout_transform; + + using PostMmaXForm = register_layout_transform; }; /** @}*/ diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index df09d26e..5950966c 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -153,6 +153,7 @@ namespace rocwmma // Vector size properties constexpr static uint32_t MaxVW = detail:: MaxVWSelector::Result; + constexpr static uint32_t VW = is_same_v || BlockDim > 32u ? MaxVW : 1u; @@ -162,9 +163,7 @@ namespace rocwmma // Matrix Layouts // Small dim mma friendly using SmallDimMatrixLayout - = conditional_t, - MatrixLayout::ColOrthoVW, - MatrixLayout::ColOrthoVW>; + = MatrixLayout::ColOrthoVW; // Large dim not mma friendly using LargeDimMatrixLayout @@ -179,15 +178,14 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires input duplication. using MmaLayout = RegisterLayout::MmaInput; - // Fragments will keep storage register layout. - // No post-load / pre-store xform - // May require pre-mma xform + // Fragments will keep storage layout using FragmentLayout = StorageLayout; }; @@ -201,6 +199,7 @@ namespace rocwmma // Vector size properties constexpr static uint32_t MaxVW = detail:: MaxVWSelector::Result; + constexpr static uint32_t VW = is_same_v || BlockDim > 32 ? MaxVW : 1u; @@ -209,10 +208,8 @@ namespace rocwmma // Matrix Layouts // Small dim mma friendly - using SmallDimMatrixLayout - = conditional_t, - MatrixLayout::RowOrthoVW, - MatrixLayout::RowOrthoVW>; + using SmallDimMatrixLayout = + MatrixLayout::RowOrthoVW; // Large dim not mma friendly using LargeDimMatrixLayout @@ -227,6 +224,7 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires input duplication. using MmaLayout = RegisterLayout::MmaInput; // Fragments will keep storage register layout. - // No post-load / pre-store xform - // May require pre-mma xform using FragmentLayout = StorageLayout; }; @@ -250,6 +246,7 @@ namespace rocwmma // Vector size properties constexpr static uint32_t MaxVW = detail:: MaxVWSelector::Result; + constexpr static uint32_t VW = is_same_v ? MaxVW : 1u; // DataLayout @@ -257,14 +254,13 @@ namespace rocwmma // Always mma friendly using MatrixLayout - = conditional_t, - MatrixLayout::RowOrthoVW, - MatrixLayout::RowOrthoVW>; + = MatrixLayout::RowOrthoVW; // Register layout direct to memory storage (load / store) using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires padded acc. using MmaLayout = RegisterLayout::MmaAcc; // Fragments will keep storage register layout. - // No post-load / pre-store xform - // May require pre-mma xform. - // TODO: Ideally, should really be MmaLayout - // However, MmaAcc frags are restricted to 16/32 MmaDim. - // Once restriction is lifted, should be adjusted. using FragmentLayout = StorageLayout; }; @@ -288,6 +279,7 @@ namespace rocwmma using StorageLayout = void; // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires padded acc. using MmaLayout = RegisterLayout::MmaAcc; - // Fragments will keep mma register layout. - // No pre-mma xform - using FragmentLayout = MmaLayout; + // Fragments will assume default mma register layout. + using FragmentLayout = RegisterLayout::MmaAcc; }; namespace detail @@ -361,6 +355,7 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect interleaved SOA format. + // Quirk: gfx11 requires input duplication. using MmaLayout = RegisterLayout::MmaInput; // Fragments will keep storage register layout. - // No post-load / pre-store xform - // May require pre-mma xform using FragmentLayout = StorageLayout; // Vector size properties derived from the matrix layout @@ -401,6 +394,7 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect interleaved SOA format. + // Quirk: gfx11 requires input duplication. using MmaLayout = RegisterLayout::MmaInput; // Fragments will keep storage register layout. - // No post-load / pre-store xform - // May require pre-mma xform using FragmentLayout = StorageLayout; // Vector size properties derived from the matrix layout @@ -440,6 +432,7 @@ namespace rocwmma using StorageLayout = RegisterLayout::Storage; // Register layout required for mma. Expect interleaved accum format for multiple blocks. + // Quirk: gfx11 requires padded mma acc using MmaLayout = RegisterLayout::MmaAcc; // Fragments will keep mma register layout. - // May require post-load / pre-store xform - // No pre-mma xform - using FragmentLayout = MmaLayout; + using FragmentLayout + = RegisterLayout::MmaAcc; // Vector size properties derived from the matrix layout constexpr static uint32_t MaxVW = layout_traits::MaxVectorWidth; @@ -468,6 +463,7 @@ namespace rocwmma using StorageLayout = void; // Register layout required for mma. Expect interleaved accum format for multiple blocks. + // Quirk: gfx11 requires padded mma acc using MmaLayout = RegisterLayout::MmaAcc; - // Fragments will keep mma register layout. - // No pre-mma xform - using FragmentLayout = MmaLayout; + // Fragments will keep mma interleaved layout. + using FragmentLayout = RegisterLayout::MmaAcc; }; } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 67e7c6bf..00174b61 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -202,13 +202,15 @@ namespace rocwmma } else { - return (traits::Format == Format::WMMA_ACC_GFX11); + // Acc with void datalayout will take SOA format + return (traits::Format == Format::WMMA_ACC_GFX11) + || (traits::Format == Format::SOA); } } else { return traits::is_storage - && ((traits::Format == Format::SOA) + && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) || (traits::Format == Format::SOA_INT) || (traits::Format == Format::AOS_INT)); @@ -483,7 +485,7 @@ namespace rocwmma = conditional_t; // Gfx11 MmaInput requires some additional transforms - if constexpr((bool)ROCWMMA_ARCH_GFX11 + if constexpr((bool)ROCWMMA_ARCH_GFX11 && (traits_lhs::is_mma_input || traits_rhs::is_mma_input)) { return TestCompatibleParams && TestFormatMatch; @@ -496,7 +498,7 @@ namespace rocwmma constexpr bool TestIdentityQuirks = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); - return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); + return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); } } else diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 2de7ae6b..28275eec 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -153,8 +153,8 @@ namespace rocwmma = conditional_t; return interleave<1u, storage_traits::KPerThread>(forward(v)); } - else if constexpr((traits_lhs::Format == Format::SOA || traits_lhs::Format == Format::AOS) - && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + else if constexpr((traits_lhs::Format == Format::SOA) + && (traits_rhs::Format == Format::WMMA_INPUT_GFX11)) { // Input is unpacked using VecTraits = VecTraits>; @@ -168,6 +168,43 @@ namespace rocwmma auto result = PackUtil::unpack(concat(packed, swapped)); return result; // Return by copy } + else if constexpr((traits_lhs::Format == Format::AOS) + && (traits_rhs::Format == Format::WMMA_INPUT_GFX11)) + { + + //auto toSOA = + // Input is unpacked + using VecTraits = VecTraits>; + using PackUtil = PackUtil; + + // Swap upper / lower 16's and then concatenate them + // to make sure we have each K value in each half. + // GFX11 wmma layout quirk needs the duplication. + auto packed = PackUtil::pack(v); + auto swapped = Swizzle::Swap16::exec(packed); + auto result = PackUtil::unpack(concat(packed, swapped)); + return result; // Return by copy + + } + else if constexpr((traits_lhs::Format == Format::SOA) + && (traits_rhs::Format == Format::WMMA_ACC_GFX11)) + { + // SOA format to wmma acc padded accumulator (gfx11). + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + auto accum = PackUtil::unpack(PackUtil::template pad<>(v)); + return accum; // Return by copy + } + else if constexpr((traits_lhs::Format == Format::WMMA_ACC_GFX11) + && (traits_rhs::Format == Format::SOA)) + { + // Padded wmma acc (gfx11) back to SOA format. + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + return PackUtil::template unpad<>(PackUtil::pack(v)); + } else { static_assert(0, "Register layout transform is not implemented"); @@ -190,6 +227,8 @@ namespace rocwmma using register_layout_transform = RegisterTransform_impl::register_layout_transform; + using register_layout_transform_nop = register_layout_transform; + } // namespace rocWMMA #endif // ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP diff --git a/library/include/rocwmma/internal/wmma.hpp b/library/include/rocwmma/internal/wmma.hpp index 74525c20..31ad64b1 100644 --- a/library/include/rocwmma/internal/wmma.hpp +++ b/library/include/rocwmma/internal/wmma.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -130,18 +130,13 @@ namespace rocwmma { // Inputs from outside will come in as fully packed static_assert(VecTraits::size() == VecTraitsA::size() * Traits::WmmaCount, - "WMMA input size mismatch"); - static_assert(VecTraits::size() == VecTraitsA::size() * Traits::WmmaCount, - "WMMA input size mismatch"); - static_assert(VecTraits::size() == IOTraitsAcc::PackedSize, - "WMMA input size mismatch"); - - // WMMA accumulator operates on unpacked, padded data in separate 32b elements. - // In the case of f16, what needs to happen is extend each unpacked element to 32b wide - // and shift the 16b data to the correct spot (determined by the WMMA backend). - // The nasty bit is that due of the extended 32b element size, the final accumulation vector - // is masqueraded as a 'packed' type, but with the same vector size as unpacked. - auto accum = PackUtil::template pad(PackUtil::unpack(regsC)); + "WMMA A input size mismatch"); + static_assert(VecTraits::size() == VecTraitsB::size() * Traits::WmmaCount, + "WMMA B input size mismatch"); + static_assert(VecTraits::size() == VecTraitsC::size(), + "WMMA Acc input size mismatch"); + + auto accum = regsC; // Iterate over packed WMMA inputs auto const aIt = makeVectorIterator(regsA).begin(); @@ -156,7 +151,7 @@ namespace rocwmma bIt++; } - return PackUtil::pack(PackUtil::template unpad(accum)); + return accum; } }; @@ -164,4 +159,4 @@ namespace rocwmma } // namespace rocwmma -#endif // ROCWMMA_WMMA_HPP \ No newline at end of file +#endif // ROCWMMA_WMMA_HPP diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index 4968764e..fda96c1c 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -344,6 +344,7 @@ namespace rocwmma using PreMmaA = typename IOConfigA::PreMmaXForm; using PreMmaB = typename IOConfigB::PreMmaXForm; using PreMmaAcc = typename IOConfigAcc::PreMmaXForm; + using PostMmaAcc = typename IOConfigAcc::PostMmaXForm; using PackA = typename IOConfigA::PackUtil; using PackB = typename IOConfigB::PackUtil; @@ -364,15 +365,18 @@ namespace rocwmma "Input fragment register layouts do not match"); // Gfx9 uses MFMA, gfx11 uses WMMA - using MMA = conditional_t<(bool)ROCWMMA_ARCH_GFX9, + using Mma = conditional_t<(bool)ROCWMMA_ARCH_GFX9, Mfma, Wmma>; - // Operate pre-ops on unpacked vectors - // the pack for mma inputs - (*d) = MMA::exec(PackA::pack(PreMmaA::exec(a.mAccess)), - PackB::pack(PreMmaB::exec(b.mAccess)), - PackAcc::pack(PreMmaAcc::exec(c.mAccess))); + // 1. Perform input pre-ops on A, B, Acc (unpacked) + // 2. Mma (packed) + // 3. Perform acc post-op on Acc + // 4. Pack back to register + d.mAccess = PostMmaAcc::exec( + PackAcc::unpack(Mma::exec(PackA::pack(PreMmaA::exec(a.mAccess)), + PackB::pack(PreMmaB::exec(b.mAccess)), + PackAcc::pack(PreMmaAcc::exec(c.mAccess))))); } ROCWMMA_DEVICE void synchronize_workgroup() From e3c61a32843904221ac2007956a5cd3248869ef1 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Wed, 11 Dec 2024 16:37:19 -0700 Subject: [PATCH 33/36] Fixup MmaDim calculator --- .../include/rocwmma/internal/io_layout.hpp | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 5950966c..85a4a6a9 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -298,17 +298,37 @@ namespace rocwmma { template + uint32_t MmaDim = (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED ? 32u : 16u> struct MmaDimSelector { private: - // Try to get the best interleaved VW along BlockDim axis. - static constexpr uint32_t SizeB128 = 128u >> 2u; - static constexpr uint32_t InterleaveVW = BlockDim / TestMmaDim; - static constexpr uint32_t BytesPerThread = InterleaveVW * sizeof(DataT); + // Smallest valid mma dim for mfma/wmma. + // Test MmaDim must not exceed BlockDim for valid layout. + static constexpr uint32_t MinMmaDim = 16u; + static constexpr uint32_t TestMmaDim = std::min(BlockDim, MmaDim); + + // For valid mma sizes, (BlockDim >= 16) + // Find minimum 16 byte load with MmaDim of 32 or 16 + static constexpr uint32_t MinLargeBytes = 16u; + static constexpr uint32_t DimPerThread = BlockDim / TestMmaDim; + static constexpr uint32_t BytesPerThread = DimPerThread * sizeof(DataT); + static constexpr uint32_t MmaDimResult = (BytesPerThread < MinLargeBytes ? MinMmaDim : TestMmaDim); + + // For invalid mma sizes (BlockDim < 16), we can have smaller MmaDim to increase VW. + // Try to balance DimPerThread and KPerThread by aiming to get half BlockDim bytes. + static constexpr bool SmallDim = TestMmaDim < MinMmaDim; + static constexpr uint32_t MinSmallBytes = BlockDim / 2u * sizeof(DataT); + static constexpr uint32_t SmallDimResult = (BytesPerThread < MinSmallBytes) ? + MmaDimSelector::Result : TestMmaDim; public: - static constexpr uint32_t Result = (BytesPerThread < SizeB128 ? 16u : TestMmaDim); + static constexpr uint32_t Result = SmallDim ? SmallDimResult : MmaDimResult; + }; + + template + struct MmaDimSelector + { + static constexpr uint32_t Result = 1u; }; } // namespace detail @@ -368,7 +388,7 @@ namespace rocwmma // Vector size properties derived from the matrix layout constexpr static uint32_t MaxVW = layout_traits::MaxVectorWidth; - constexpr static uint32_t VW = MaxVW; + constexpr static uint32_t VW = MaxVW; }; template Date: Tue, 17 Dec 2024 08:06:22 -0700 Subject: [PATCH 34/36] Removed WMMA_ACC_INT* formats --- .../include/rocwmma/internal/io_layout.hpp | 4 +- .../rocwmma/internal/layout/layout.hpp | 2 - .../layout/register_layout_traits_impl.hpp | 154 ++++++------ .../layout/register_layout_transforms.hpp | 223 ++++++++++++------ library/include/rocwmma/internal/wmma.hpp | 2 +- .../device/layout_traits_int.hpp | 17 +- 6 files changed, 249 insertions(+), 153 deletions(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 85a4a6a9..b64303b2 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -458,7 +458,7 @@ namespace rocwmma DataT, true, (bool)ROCWMMA_ARCH_GFX11 - ? RegisterLayout::Format::WMMA_ACC_INT_A_MAJOR_GFX11 + ? RegisterLayout::Format::WMMA_ACC_GFX11 : RegisterLayout::Format::ACC_INT_A_MAJOR>; // Fragments will keep mma register layout. @@ -489,7 +489,7 @@ namespace rocwmma DataT, true, (bool)ROCWMMA_ARCH_GFX11 - ? RegisterLayout::Format::WMMA_ACC_INT_A_MAJOR_GFX11 + ? RegisterLayout::Format::WMMA_ACC_GFX11 : RegisterLayout::Format::ACC_INT_A_MAJOR>; // Fragments will keep mma interleaved layout. diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp index d9d6cb64..64c5132d 100644 --- a/library/include/rocwmma/internal/layout/layout.hpp +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -155,8 +155,6 @@ namespace rocwmma ACC_INT_B_MAJOR = 5u, // Interleaved MmaAcc 'B' major order WMMA_INPUT_GFX11 = 6u, // Gfx11 input format WMMA_ACC_GFX11 = 7u, // Gfx11 acc format - WMMA_ACC_INT_A_MAJOR_GFX11 = 8u, // Gfx11 interleaved MmaAcc 'A' major order - WMMA_ACC_INT_B_MAJOR_GFX11 = 9u, // Gfx11 interleaved MmaAcc 'B' major order Invalid, // Invalid register format }; diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp index 00174b61..5e88b8ea 100644 --- a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -186,71 +186,53 @@ namespace rocwmma { using traits = register_layout_traits; using rocwmma::RegisterLayout::Format; - if constexpr((bool)ROCWMMA_ARCH_GFX11) + + if constexpr(traits::is_mma_input) { - if constexpr(traits::is_mma_input) + if constexpr((bool)ROCWMMA_ARCH_GFX11) { return traits::Format == Format::WMMA_INPUT_GFX11; } - else if constexpr(traits::is_mma_acc) + else if constexpr(traits::is_interleaved) { - if constexpr(traits::is_interleaved) - { - // Intermediate accumulation format for interleaved layout - return (traits::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) - || (traits::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11); - } - else - { - // Acc with void datalayout will take SOA format - return (traits::Format == Format::WMMA_ACC_GFX11) - || (traits::Format == Format::SOA); - } + return (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT); } else { - return traits::is_storage - && ((traits::Format == Format::SOA) - || (traits::Format == Format::AOS) - || (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT)); + return (traits::Format == Format::SOA) + || (traits::Format == Format::AOS); } } - else // Other archs + else if constexpr(traits::is_mma_acc) { - if constexpr(traits::is_mma_input) + if constexpr((bool)ROCWMMA_ARCH_GFX11) { - if constexpr(traits::is_interleaved) - { - return (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT); - } - else - { - return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); - } + return (traits::Format == Format::WMMA_ACC_GFX11) + || (!traits::is_interleaved && (traits::Format == Format::SOA || traits::Format == Format::AOS)) + || (traits::is_interleaved && (traits::Format == Format::ACC_INT_A_MAJOR || traits::Format == Format::ACC_INT_B_MAJOR)); } - else if constexpr(traits::is_mma_acc) + else if constexpr(traits::is_interleaved) { - if constexpr(traits::is_interleaved) - { - // Intermediate accumulation format for interleaved layout - return (traits::Format == Format::ACC_INT_A_MAJOR) - || (traits::Format == Format::ACC_INT_B_MAJOR); - } - else - { - return (traits::Format == Format::SOA) || (traits::Format == Format::AOS); - } + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::ACC_INT_A_MAJOR) + || (traits::Format == Format::ACC_INT_B_MAJOR); } else { - return traits::is_storage - && ((traits::Format == Format::SOA) || (traits::Format == Format::AOS) - || (traits::Format == Format::SOA_INT) - || (traits::Format == Format::AOS_INT)); + // Acc with void datalayout will take SOA format + return (traits::Format == Format::SOA) + || (traits::Format == Format::AOS); } } + else + { + return traits::is_storage + && ((traits::Format == Format::SOA) + || (traits::Format == Format::AOS) + || (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT)); + } } template @@ -526,39 +508,53 @@ namespace rocwmma // clang-format off using RegisterLayout::Format; constexpr bool TestOpposingFormat - = ( (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) - || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) - || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) - // gfx11 transforms - || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) - || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) - || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA) - || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS) - || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_ACC_GFX11) - || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_ACC_GFX11) - || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA) - || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11) - || (traits_lhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11 && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::WMMA_ACC_INT_A_MAJOR_GFX11 && traits_rhs::Format == Format::AOS_INT) - || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11) - || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11) - || (traits_lhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11 && traits_rhs::Format == Format::SOA_INT) - || (traits_lhs::Format == Format::WMMA_ACC_INT_B_MAJOR_GFX11 && traits_rhs::Format == Format::AOS_INT) + = ( + // Non-interleaved formats + // SOA <-> AOS + (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) + // Non-interleaved gfx11 formats + // SOA, AOS <-> WMMA input + // SOA, AOS <-> WMMA acc + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS) + // Interleaved formats + // SOA_INT <-> AOS_INT + // SOA_INT, AOS_INT <-> A-major acc fmt + // SOA_INT, AOS_INT <-> B-major acc fmt + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::AOS_INT) + // Interleaved gfx11 formats + // SOA_INT, AOS_INT <-> WMMA input + // SOA_INT, AOS_INT <-> WMMA acc + // A-major acc fmt <-> WMMA acc + // B-major acc fmt <-> WMMA acc + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::ACC_INT_B_MAJOR) ) && (traits_lhs::is_valid && traits_rhs::is_valid); // clang-format on diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp index 28275eec..b31d63a3 100644 --- a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -32,6 +32,113 @@ namespace rocwmma { + template + struct soa_int_to_aos_int + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, KPerThread>(forward(v)); + } + }; + + template + struct aos_int_to_soa_int + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, DimPerThread>(forward(v)); + } + }; + + struct to_wmma_input_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // v is unpacked + using VecTraits = VecTraits>; + using PackUtil = PackUtil; + + // Swap upper / lower 16's and then concatenate them + // to make sure we have each K value in each half. + // GFX11 wmma layout quirk needs the duplication. + auto packed = PackUtil::pack(v); + auto swapped = Swizzle::Swap16::exec(packed); + auto result = PackUtil::unpack(concat(packed, swapped)); + return result; // Return by copy + } + }; + + struct from_wmma_input_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return extractLo(v); + } + }; + + struct to_wmma_acc_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + using VecTraits = VecTraits>; + + // SOA format to wmma acc padded accumulator (gfx11). + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + auto accum = PackUtil::unpack(PackUtil::template pad<>(v)); + return accum; // Return by copy + } + }; + + struct from_wmma_acc_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + using VecTraits = VecTraits>; + + // Padded wmma acc (gfx11) back to SOA format. + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + return PackUtil::template unpad<>(PackUtil::pack(v)); + } + }; + + template + struct soa_int_to_mma_acc_int_a_major + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + if constexpr((bool)ROCWMMA_ARCH_GFX11) + { + + } + else + { + + } + return interleave<1u, DimPerThread>(forward(v)); + } + }; + + template + struct aos_int_to_mma_acc_int_a_major + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, DimPerThread>(forward(v)); + } + }; + namespace RegisterTransform_impl { using LayoutTraits_impl::matrix_layout_traits; @@ -91,113 +198,93 @@ namespace rocwmma ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) { using RegisterLayout::Format; + using storage_traits + = conditional_t; - // Non-interleaved AOS to SOA - if constexpr(traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) + // Non-interleaved + if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::SOA) { - using storage_traits - = conditional_t; return Transforms:: AosToSoa::exec( forward(v)); } else if constexpr(traits_lhs::Format == Format::SOA - && traits_rhs::Format == Format::AOS) + && traits_rhs::Format == Format::AOS) { - using storage_traits - = conditional_t; return Transforms:: SoaToAos::exec( forward(v)); } + else if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(Transforms::AosToSoa::exec(forward(v))); + } + else if constexpr(traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(forward(v)); + } + // Interleaved else if constexpr(traits_lhs::Format == Format::AOS_INT - && traits_rhs::Format == Format::SOA_INT) + && traits_rhs::Format == Format::SOA_INT) { - using storage_traits - = conditional_t; - return interleave<1u, storage_traits::DimPerThread>(forward(v)); + return aos_int_to_soa_int::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::AOS_INT) + { + return soa_int_to_aos_int::exec(forward(v)); } else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(aos_int_to_soa_int::exec(forward(v))); + } + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::AOS_INT) { - using storage_traits - = conditional_t; - return interleave<1u, storage_traits::KPerThread>(forward(v)); + return interleave<1u, 4u>(forward(v)); } else if constexpr(traits_lhs::Format == Format::AOS_INT - && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) { - using storage_traits - = conditional_t; return interleave<1u, storage_traits::KPerThread>(forward(v)); } else if constexpr(traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) { - using storage_traits - = conditional_t; - - return interleave<1u, 4u>(forward(v)); - } - else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR - && traits_rhs::Format == Format::AOS_INT) - { - using storage_traits - = conditional_t; return interleave<1u, 4u>(forward(v)); } + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) { - using storage_traits - = conditional_t; return interleave<1u, storage_traits::KPerThread>(forward(v)); } - else if constexpr((traits_lhs::Format == Format::SOA) - && (traits_rhs::Format == Format::WMMA_INPUT_GFX11)) - { - // Input is unpacked - using VecTraits = VecTraits>; - using PackUtil = PackUtil; - // Swap upper / lower 16's and then concatenate them - // to make sure we have each K value in each half. - // GFX11 wmma layout quirk needs the duplication. - auto packed = PackUtil::pack(v); - auto swapped = Swizzle::Swap16::exec(packed); - auto result = PackUtil::unpack(concat(packed, swapped)); - return result; // Return by copy - } - else if constexpr((traits_lhs::Format == Format::AOS) - && (traits_rhs::Format == Format::WMMA_INPUT_GFX11)) + else if constexpr((traits_lhs::Format == Format::SOA + || traits_lhs::Format == Format::ACC_INT_A_MAJOR + || traits_lhs::Format == Format::ACC_INT_B_MAJOR) + && (traits_rhs::Format == Format::WMMA_ACC_GFX11)) { - - //auto toSOA = - // Input is unpacked - using VecTraits = VecTraits>; - using PackUtil = PackUtil; - - // Swap upper / lower 16's and then concatenate them - // to make sure we have each K value in each half. - // GFX11 wmma layout quirk needs the duplication. - auto packed = PackUtil::pack(v); - auto swapped = Swizzle::Swap16::exec(packed); - auto result = PackUtil::unpack(concat(packed, swapped)); - return result; // Return by copy - + return to_wmma_acc_gfx11::exec(forward(v)); } - else if constexpr((traits_lhs::Format == Format::SOA) - && (traits_rhs::Format == Format::WMMA_ACC_GFX11)) + else if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::WMMA_ACC_GFX11) { - // SOA format to wmma acc padded accumulator (gfx11). - // f16 -> padded to f32 in lower 16 - // f32 -> nop - using PackUtil = PackUtil; - auto accum = PackUtil::unpack(PackUtil::template pad<>(v)); - return accum; // Return by copy + return to_wmma_acc_gfx11::exec(forward(v)); } else if constexpr((traits_lhs::Format == Format::WMMA_ACC_GFX11) - && (traits_rhs::Format == Format::SOA)) + && (traits_rhs::Format == Format::SOA + || traits_rhs::Format == Format::ACC_INT_A_MAJOR + || traits_rhs::Format == Format::ACC_INT_B_MAJOR)) { // Padded wmma acc (gfx11) back to SOA format. // f16 -> padded to f32 in lower 16 diff --git a/library/include/rocwmma/internal/wmma.hpp b/library/include/rocwmma/internal/wmma.hpp index 31ad64b1..e4667c3b 100644 --- a/library/include/rocwmma/internal/wmma.hpp +++ b/library/include/rocwmma/internal/wmma.hpp @@ -155,7 +155,7 @@ namespace rocwmma } }; -#endif // ROCWMMA_ARCH_GFX11 +#endif // ROCWMMA_ARCH_GFX11 || ROCWMMA_ARCH_GFX12 } // namespace rocwmma diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp index ce47a0ec..acd69f54 100644 --- a/test/unit/layout_traits_test/device/layout_traits_int.hpp +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -75,6 +75,21 @@ namespace rocwmma && (blockIdx.y == 0) && (blockIdx.z == 0); } + template + ROCWMMA_DEVICE constexpr void debugRegisterFormats() + { + if(isFirstThread()) + { + using traits_lhs = layout_traits; + using traits_rhs = layout_traits; + printf("testCompatibleRegisterParams: %d\n", LayoutTraits_impl::testCompatibleRegisterParams()); + printf("MmaDim: %d, MmaDim: %d\n", traits_lhs::MmaDim, traits_rhs::MmaDim); + printf("DataFormat: %d, DataFormat: %d\n", (int)traits_lhs::Format, (int)traits_rhs::Format); + printf("is_valid: %d, is_valid: %d\n", traits_lhs::is_valid, traits_rhs::is_valid); + printf("is_same_dataT: %d\n", is_same_v); + } + } + template ; using MmaAcc = RegisterLayout::MmaAcc; }; From 45cf20e55586d020ecf56e3deab338bbc713e5d7 Mon Sep 17 00:00:00 2001 From: Chris Millette Date: Tue, 17 Dec 2024 10:15:30 -0700 Subject: [PATCH 35/36] removes std min reference for hipRTC --- library/include/rocwmma/internal/io_layout.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index b64303b2..ec509e6b 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -305,7 +305,7 @@ namespace rocwmma // Smallest valid mma dim for mfma/wmma. // Test MmaDim must not exceed BlockDim for valid layout. static constexpr uint32_t MinMmaDim = 16u; - static constexpr uint32_t TestMmaDim = std::min(BlockDim, MmaDim); + static constexpr uint32_t TestMmaDim = min(BlockDim, MmaDim); // For valid mma sizes, (BlockDim >= 16) // Find minimum 16 byte load with MmaDim of 32 or 16 From 19af315db6d5a969c8506983ea3977aec4515929 Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Tue, 17 Dec 2024 19:54:36 -0700 Subject: [PATCH 36/36] Update perf_hgemm.cpp fix newline --- samples/perf_hgemm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/samples/perf_hgemm.cpp b/samples/perf_hgemm.cpp index 766f68b8..d5ca899a 100644 --- a/samples/perf_hgemm.cpp +++ b/samples/perf_hgemm.cpp @@ -944,4 +944,5 @@ int main() { gemm_test(7168, 7168, 7168, 2, 2); return 0; -} \ No newline at end of file +} +