diff --git a/paddle/phi/kernels/gpu/strided_copy_kernel.cu b/paddle/phi/kernels/gpu/strided_copy_kernel.cu index e72eca2f936e19..65dae3fc89efe9 100644 --- a/paddle/phi/kernels/gpu/strided_copy_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_copy_kernel.cu @@ -48,127 +48,6 @@ __global__ void StridedCopyFunc( } } -template -__global__ void StridedCopyCaseZeroFunc( - const T* input_data, - phi::Array input_stride, - T* output_data, - phi::Array output_stride) { - int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - int64_t output_offset = input_offset; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; - output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void StridedCopyCaseOneFunc( - const T* input_data, - phi::Array input_stride, - T* out_data, - phi::Array output_stride, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - int64_t output_offset = input_offset; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - input_offset += coordinate[N - 1 - dim] * input_stride[dim]; - output_offset += coordinate[N - 1 - dim] * output_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template __global__ void Strided2ContiguousFunc( const T* input_data, @@ -192,123 +71,6 @@ __global__ void Strided2ContiguousFunc( } } -template -__global__ void Strided2ContiguousCaseZeroFunc( - const T* input_data, - phi::Array input_stride, - T* output_data) { - int64_t input_offset = 0; - int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void Strided2ContiguousCaseOneFunc( - const T* input_data, - phi::Array input_stride, - T* out_data, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = 0; - int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - input_offset += coordinate[N - 1 - dim] * input_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template __global__ void Contiguous2StridedFunc( const T* input_data, @@ -332,123 +94,6 @@ __global__ void Contiguous2StridedFunc( } } -template -__global__ void Contiguous2StridedCaseZeroFunc( - const T* input_data, - T* output_data, - phi::Array output_stride) { - int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - blockDim.z * blockDim.y * blockDim.x + - threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x; - int64_t output_offset = 0; - float coordinate[6] = {threadIdx.x, - threadIdx.y, - threadIdx.z, - blockIdx.x, - blockIdx.y, - blockIdx.z}; - -#pragma unroll - for (int dim = RANK - 1; dim >= 0; --dim) { - output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; - } - - output_data[output_offset] = input_data[input_offset]; -} - -template -__global__ void Contiguous2StridedCaseOneFunc( - const T* input_data, - T* out_data, - phi::Array output_stride, - phi::Array dims, - const int64_t x_max) { - int64_t x = blockIdx.x * blockDim.x + threadIdx.x; - if (x < x_max) { - int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; - int64_t output_offset = 0; - - int64_t reg_dims[6] = { - dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; - int64_t coordinate[phi::DDim::kMaxRank + 1]; - - switch (N) { - case 1: - coordinate[0] = x % reg_dims[0]; - break; - case 2: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - break; - case 3: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - break; - case 4: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - break; - case 5: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - break; - case 6: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - break; - case 7: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - break; - case 8: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - break; - case 9: - coordinate[0] = x % reg_dims[0]; - coordinate[1] = x / reg_dims[0] % reg_dims[1]; - coordinate[2] = x / (reg_dims[0] * reg_dims[1]); - coordinate[3] = blockIdx.y % reg_dims[2]; - coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; - coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); - coordinate[6] = blockIdx.z % reg_dims[4]; - coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; - coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); - break; - } - -#pragma unroll - for (int dim = N - 1; dim >= 0; --dim) { - output_offset += coordinate[N - 1 - dim] * output_stride[dim]; - } - - out_data[output_offset] = input_data[input_offset]; - } -} - template void StridedCopyKernel(const Context& dev_ctx, const DenseTensor& input, @@ -500,6 +145,8 @@ void StridedCopyKernel(const Context& dev_ctx, } auto numel = input.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; if (numel == 1) { #ifdef PADDLE_WITH_HIP @@ -517,649 +164,1088 @@ void StridedCopyKernel(const Context& dev_ctx, return; } - dim3 grid(1, 1, 1), block(1, 1, 1); - int rank = input_rank; - int tmp = 1; - - for (int i = 0; i < 3 && i < rank; i++) { - tmp *= input_dims[rank - 1 - i]; - } - - if (rank <= 6 && tmp <= 1024 && - (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { - if (rank >= 1) { - block.x = input_dims[rank - 1]; - } - - if (rank >= 2) { - block.y = input_dims[rank - 2]; - } - - if (rank >= 3) { - block.z = input_dims[rank - 3]; + if (input.meta().is_contiguous()) { + switch (input_rank) { + case 1: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + Contiguous2StridedFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + input_rank)); } - - if (input.meta().is_contiguous()) { - switch (rank) { - case 1: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 2: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 3: - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 4: - grid.x = input_dims[rank - 4]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - Contiguous2StridedCaseZeroFunc - <<>>( - input_data, output_data, output_stride); - break; - } - } else if (out->meta().is_contiguous()) { - switch (rank) { - case 1: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 2: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 3: - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 4: - grid.x = input_dims[rank - 4]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - Strided2ContiguousCaseZeroFunc - <<>>( - input_data, input_stride, output_data); - break; - } - } else { - switch (rank) { - case 1: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 2: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 3: - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 4: - grid.x = input_dims[rank - 4]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 5: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - case 6: - grid.x = input_dims[rank - 4]; - grid.y = input_dims[rank - 5]; - grid.z = input_dims[rank - 6]; - StridedCopyCaseZeroFunc<<>>( - input_data, input_stride, output_data, output_stride); - break; - } + } else if (out->meta().is_contiguous()) { + switch (output_rank) { + case 1: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + Strided2ContiguousFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); } } else { - phi::Array cur_input_dims; - block.x = 512; - - if (input.meta().is_contiguous()) { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Contiguous2StridedCaseOneFunc - <<>>( - input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Contiguous2StridedCaseOneFunc - <<>>(input_data, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } - } else if (out->meta().is_contiguous()) { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Strided2ContiguousCaseOneFunc - <<>>( - input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - Strided2ContiguousCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - cur_input_dims, - input_dims[rank - 1] * - input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } - } else { - switch (rank) { - case 1: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - StridedCopyCaseOneFunc - <<>>(input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1]); - break; - case 2: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2]); - break; - case 3: - grid.x = (numel + block.x - 1) / block.x; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 4: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 5: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 6: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 7: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 8: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - case 9: - grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3] + - block.x - 1) / - block.x; - grid.y = input_dims[rank - 4] * input_dims[rank - 5] * - input_dims[rank - 6]; - grid.z = input_dims[rank - 7] * input_dims[rank - 8] * - input_dims[rank - 9]; - cur_input_dims[0] = input_dims[rank - 1]; - cur_input_dims[1] = input_dims[rank - 2]; - cur_input_dims[2] = input_dims[rank - 4]; - cur_input_dims[3] = input_dims[rank - 5]; - cur_input_dims[4] = input_dims[rank - 7]; - cur_input_dims[5] = input_dims[rank - 8]; - StridedCopyCaseOneFunc<<>>( - input_data, - input_stride, - output_data, - output_stride, - cur_input_dims, - input_dims[rank - 1] * input_dims[rank - 2] * - input_dims[rank - 3]); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - rank)); - } + switch (input_rank) { + case 1: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 2: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 3: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 4: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 5: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 6: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 7: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 8: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + case 9: { + switch (output_rank) { + case 1: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 2: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 3: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 4: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 5: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 6: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 7: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 8: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + case 9: + StridedCopyFunc + <<>>(input_data, + input_dims, + input_stride, + output_data, + output_dims, + output_stride, + numel); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of output should be less than 9, but received %d.", + output_rank)); + } + } break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + input_rank)); } } }