diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 0a0121b660ef..18679f86e82c 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1,4 +1,3 @@ - #include #include #include @@ -327,71 +326,193 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, ///////////////////////////////////////////// -using half8 = __attribute__((__vector_size__(4 * sizeof(float)))) float; - -/*template -__device__ __forceinline__ T loadnt(T* addr) { - return __builtin_nontemporal_load(addr); - //return *((T*)addr); -}*/ - -#define THRDS 64 -#define YTILE 2 -#define WvPrGrp 16 -#define A_CHUNK 8 -#define UNRL 2 -#define M 1 #define DTYPE half -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +__device__ __forceinline__ int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} -__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- __shared__ half s[1024 * 32]; - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + // uint32_t commitColumn[YTILE]; + // for (uint32_t i = 0; i < YTILE; i++) { + // commitColumn[i] = 1; + //} + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + if (threadIdx.y >= _WvPrGrp) return; + float sum[M][YTILE]; + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- for (int i = 0; i < YTILE; i++) for (int m = 0; m < M; m++) sum[m][i] = 0; bigType bigA[M][UNRL]; bigType bigB0[UNRL]; - #if (YTILE >= 2) bigType bigB1[UNRL]; - #endif + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - #if (YTILE >= 2) - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - #endif + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } + // Fetch activation matrix from either just LDS or from both LDS / memory #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { @@ -400,8 +521,12 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; // Fetch A activation matrix in interleaved fashion from LDS or memory + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); } } @@ -411,10 +536,10 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - #pragma unroll - for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t m = 0; m < M; m++) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -424,11 +549,34 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - #if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - #endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } @@ -459,38 +607,50 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); } } - if (threadIdx.x == 63) { for (int m = 0; m < M; m++) { for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); C[n + i + m * N] = __float2half(sum[m][i]); } } } - n += CuCount * WvPrGrp * YTILE; + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} } } - -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount){UNREACHABLE_CODE} - +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support - -__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; @@ -511,12 +671,15 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, commitColumn[i] = 1; } + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -547,11 +710,14 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + if (threadIdx.y >= _WvPrGrp) return; + float sum[M][YTILE]; //---------------------------------------------------- @@ -582,36 +748,14 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; - #if (YTILE >= 2) bigType bigB1[UNRL]; - #endif - #if (YTILE >= 3) bigType bigB2[UNRL]; - #endif - #if (YTILE >= 4) bigType bigB3[UNRL]; - #endif - #if (YTILE >= 5) bigType bigB4[UNRL]; - #endif - #if (YTILE >= 6) bigType bigB5[UNRL]; - #endif - #if (YTILE >= 7) bigType bigB6[UNRL]; - #endif - #if (YTILE >= 8) bigType bigB7[UNRL]; - #endif - #if (YTILE >= 9) bigType bigB8[UNRL]; - #endif - #if (YTILE >= 10) - bigType bigB9[UNRL]; - #endif - #if (YTILE >= 11) - bigType bigB10[UNRL]; - #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -637,51 +781,18 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - #if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - #endif - #if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - #endif - #if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - #endif - #if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - #endif - #if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - #endif - #if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - #endif - #if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - #endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -703,14 +814,14 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; + for (uint32_t m = 0; m < M; m++) { #pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -720,56 +831,34 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - #if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - #endif - #if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - #endif - #if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - #endif - #if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - #endif - #if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - #endif - #if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - #endif - #if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - #endif - #if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); - #endif - #if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); - #endif - #if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); - #endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } @@ -809,11 +898,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } } - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); + n += CuCount * _WvPrGrp * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -827,33 +912,29 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount){UNREACHABLE_CODE} - +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#undef YTILE -#undef UNRL -#undef M - -#define YTILE 2 -#define UNRL 2 -#define M 2 - #if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; -__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; - __int128_t b128; half8 h8; }; @@ -874,12 +955,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, commitColumn[i] = 1; } + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + //---------------------------------------------------- // Indexing function into the column of weight matrix B // Algorithm does 64 lane k-splitting / wave and uses // WG ID and Thread ID to find the index. //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -900,6 +985,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- + #define PCML + #ifndef PCML for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); @@ -910,10 +997,24 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, if (k_in >= min(K * M, 32 * 1024)) break; - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); float sum[M][YTILE]; @@ -932,7 +1033,13 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // - After completing first set of columns, WGs start // working on the next set of available columns //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); + while (n < Nrndp) { + #else while (n < N) { + #endif //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation // split across 64 lanes. @@ -945,36 +1052,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; - #if (YTILE >= 2) bigType bigB1[UNRL]; - #endif - #if (YTILE >= 3) bigType bigB2[UNRL]; - #endif - #if (YTILE >= 4) bigType bigB3[UNRL]; - #endif - #if (YTILE >= 5) bigType bigB4[UNRL]; - #endif - #if (YTILE >= 6) bigType bigB5[UNRL]; - #endif - #if (YTILE >= 7) bigType bigB6[UNRL]; - #endif - #if (YTILE >= 8) bigType bigB7[UNRL]; - #endif - #if (YTILE >= 9) bigType bigB8[UNRL]; - #endif - #if (YTILE >= 10) bigType bigB9[UNRL]; - #endif - #if (YTILE >= 11) bigType bigB10[UNRL]; - #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -993,58 +1080,44 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - #if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - #endif - #if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - #endif - #if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - #endif - #if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - #endif - #if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - #endif - #if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - #endif - #if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - #endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1057,10 +1130,14 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // Fetch A activation matrix in interleaved fashion from LDS or memory for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else if (k_ + K * m < 32 * 1024) bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); else bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif } } @@ -1083,61 +1160,47 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- - #if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - #endif - #if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - #endif - #if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - #endif - #if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - #endif - #if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - #endif - #if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - #endif - #if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - #endif - #if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); - #endif - #if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); - #endif - #if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); - #endif + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); } } } } + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- @@ -1172,11 +1235,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } } - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; // Check whether there will be fragmenation! // This will happen only for the last wave! @@ -1189,781 +1249,62 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } } } - -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount){UNREACHABLE_CODE} - +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#undef YTILE -#undef UNRL -#undef M +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); -#define YTILE 5 -#define UNRL 2 -#define M 3 +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + /*wvSpltK_hf:*/ \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else { \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } \ + } -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support - -__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; - } - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - #if (YTILE >= 2) - bigType bigB1[UNRL]; - #endif - #if (YTILE >= 3) - bigType bigB2[UNRL]; - #endif - #if (YTILE >= 4) - bigType bigB3[UNRL]; - #endif - #if (YTILE >= 5) - bigType bigB4[UNRL]; - #endif - #if (YTILE >= 6) - bigType bigB5[UNRL]; - #endif - #if (YTILE >= 7) - bigType bigB6[UNRL]; - #endif - #if (YTILE >= 8) - bigType bigB7[UNRL]; - #endif - #if (YTILE >= 9) - bigType bigB8[UNRL]; - #endif - #if (YTILE >= 10) - bigType bigB9[UNRL]; - #endif - #if (YTILE >= 11) - bigType bigB10[UNRL]; - #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - #if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - #endif - #if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - #endif - #if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - #endif - #if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - #endif - #if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - #endif - #if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - #endif - #if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - #endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - #pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - #if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - #endif - #if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - #endif - #if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - #endif - #if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - #endif - #if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - #endif - #if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - #endif - #if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - #endif - #if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); - #endif - #if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); - #endif - #if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); - #endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} - -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount){UNREACHABLE_CODE} - -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -#undef YTILE -#undef UNRL -#undef M - -#define YTILE 7 -#define UNRL 1 -#define M 4 - -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support - -__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; - - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) { - commitColumn[i] = 1; - } - - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint64_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32 * 1024); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for - // bank-conflict-free readback - - if (k_in >= min(K * M, 32 * 1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- - while (n < N) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m = 0; m < M; m++) sum[m][i] = 0; - - bigType bigA[M][UNRL]; - bigType bigB0[UNRL]; - #if (YTILE >= 2) - bigType bigB1[UNRL]; - #endif - #if (YTILE >= 3) - bigType bigB2[UNRL]; - #endif - #if (YTILE >= 4) - bigType bigB3[UNRL]; - #endif - #if (YTILE >= 5) - bigType bigB4[UNRL]; - #endif - #if (YTILE >= 6) - bigType bigB5[UNRL]; - #endif - #if (YTILE >= 7) - bigType bigB6[UNRL]; - #endif - #if (YTILE >= 8) - bigType bigB7[UNRL]; - #endif - #if (YTILE >= 9) - bigType bigB8[UNRL]; - #endif - #if (YTILE >= 10) - bigType bigB9[UNRL]; - #endif - #if (YTILE >= 11) - bigType bigB10[UNRL]; - #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { - // Fetch the weight matrix from memory! - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // if (k_ >= K) break; - // bool skip = (k_ >= K); - // bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - #if (YTILE >= 2) - // if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); - #endif - #if (YTILE >= 3) - // if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); - #endif - #if (YTILE >= 4) - // if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); - #endif - #if (YTILE >= 5) - // if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); - #endif - #if (YTILE >= 6) - // if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); - #endif - #if (YTILE >= 7) - // if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); - #endif - #if (YTILE >= 8) - // if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); - #endif - /* - #if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = - (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) - continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if - (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = - (loadnt((half8*)(&B_[10 * K]))); #endif - */ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m = 0; m < M; m++) { - if (k_ + K * m < 32 * 1024) - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); - } - } - - // Do the matrix multiplication in interleaved manner - #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - #pragma unroll - for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][0]) - : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - #if (YTILE >= 2) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][1]) - : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); - #endif - #if (YTILE >= 3) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][2]) - : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); - #endif - #if (YTILE >= 4) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][3]) - : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); - #endif - #if (YTILE >= 5) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][4]) - : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); - #endif - #if (YTILE >= 6) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][5]) - : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); - #endif - #if (YTILE >= 7) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][6]) - : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); - #endif - #if (YTILE >= 8) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][7]) - : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); - #endif - #if (YTILE >= 9) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][8]) - : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); - #endif - #if (YTILE >= 10) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][9]) - : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); - #endif - #if (YTILE >= 11) - asm("v_dot2c_f32_f16 %0, %2, %3" - : "=v"(sum[m][10]) - : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); - #endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[m][y]) - : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); - } - } - - if (threadIdx.x == 63) { - for (int m = 0; m < M; m++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); - } - } - } - - n += CuCount * WvPrGrp * YTILE; - - // if (threadIdx.x == 0) - // n = atomicAdd(((unsigned int*)(C)), YTILE); - // n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if (n < N && (n + YTILE) >= N) { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) { - commitColumn[i] = 0; - } - n = startColumn; - } - } -} - -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { - UNREACHABLE_CODE -} - -#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support - -void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, - const int K_in, const int N_in, cudaStream_t stream, - const int CuCount = 0) { - dim3 grid(CuCount); - dim3 block(THRDS, WvPrGrp); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto* c = reinterpret_cast(out_c); - switch (N_in) { - case 1: - if ((K_in <= 32 * 1024) && (M_in % 2 == 0)) { - wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - } else { - wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - } - break; - case 2: - wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - break; - case 3: - wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - break; - case 4: - wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, - CuCount); - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + - "," + std::to_string(K_in) + "," + - std::to_string(N_in)); - } + switch (N_in) { + case 1: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } cudaError_t err = cudaGetLastError(); if (cudaSuccess != err) { throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } -} +} \ No newline at end of file