From a9e4b68a2fe16d00e95321c895414523ab3e7c93 Mon Sep 17 00:00:00 2001 From: Winters Montagne <118546135+WintersMontagne10335@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:09:54 +0800 Subject: [PATCH] speed up strided_copy_kernel (#58033) --- paddle/phi/kernels/gpu/strided_copy_kernel.cu | 2076 ++++++++--------- 1 file changed, 995 insertions(+), 1081 deletions(-) diff --git a/paddle/phi/kernels/gpu/strided_copy_kernel.cu b/paddle/phi/kernels/gpu/strided_copy_kernel.cu index 65dae3fc89efe..e72eca2f936e1 100644 --- a/paddle/phi/kernels/gpu/strided_copy_kernel.cu +++ b/paddle/phi/kernels/gpu/strided_copy_kernel.cu @@ -48,6 +48,127 @@ __global__ void StridedCopyFunc( } } +template +__global__ void StridedCopyCaseZeroFunc( + const T* input_data, + phi::Array input_stride, + T* output_data, + phi::Array output_stride) { + int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + int64_t output_offset = input_offset; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; + output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void StridedCopyCaseOneFunc( + const T* input_data, + phi::Array input_stride, + T* out_data, + phi::Array output_stride, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + int64_t output_offset = input_offset; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + output_offset += coordinate[N - 1 - dim] * output_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template __global__ void Strided2ContiguousFunc( const T* input_data, @@ -71,6 +192,123 @@ __global__ void Strided2ContiguousFunc( } } +template +__global__ void Strided2ContiguousCaseZeroFunc( + const T* input_data, + phi::Array input_stride, + T* output_data) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + input_offset += coordinate[RANK - 1 - dim] * input_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void Strided2ContiguousCaseOneFunc( + const T* input_data, + phi::Array input_stride, + T* out_data, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = 0; + int64_t output_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + input_offset += coordinate[N - 1 - dim] * input_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template __global__ void Contiguous2StridedFunc( const T* input_data, @@ -94,6 +332,123 @@ __global__ void Contiguous2StridedFunc( } } +template +__global__ void Contiguous2StridedCaseZeroFunc( + const T* input_data, + T* output_data, + phi::Array output_stride) { + int64_t input_offset = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + blockDim.z * blockDim.y * blockDim.x + + threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x; + int64_t output_offset = 0; + float coordinate[6] = {threadIdx.x, + threadIdx.y, + threadIdx.z, + blockIdx.x, + blockIdx.y, + blockIdx.z}; + +#pragma unroll + for (int dim = RANK - 1; dim >= 0; --dim) { + output_offset += coordinate[RANK - 1 - dim] * output_stride[dim]; + } + + output_data[output_offset] = input_data[input_offset]; +} + +template +__global__ void Contiguous2StridedCaseOneFunc( + const T* input_data, + T* out_data, + phi::Array output_stride, + phi::Array dims, + const int64_t x_max) { + int64_t x = blockIdx.x * blockDim.x + threadIdx.x; + if (x < x_max) { + int64_t input_offset = (blockIdx.z * gridDim.y + blockIdx.y) * x_max + x; + int64_t output_offset = 0; + + int64_t reg_dims[6] = { + dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]}; + int64_t coordinate[phi::DDim::kMaxRank + 1]; + + switch (N) { + case 1: + coordinate[0] = x % reg_dims[0]; + break; + case 2: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + break; + case 3: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + break; + case 4: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + break; + case 5: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + break; + case 6: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + break; + case 7: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + break; + case 8: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + break; + case 9: + coordinate[0] = x % reg_dims[0]; + coordinate[1] = x / reg_dims[0] % reg_dims[1]; + coordinate[2] = x / (reg_dims[0] * reg_dims[1]); + coordinate[3] = blockIdx.y % reg_dims[2]; + coordinate[4] = blockIdx.y / reg_dims[2] % reg_dims[3]; + coordinate[5] = blockIdx.y / (reg_dims[2] * reg_dims[3]); + coordinate[6] = blockIdx.z % reg_dims[4]; + coordinate[7] = blockIdx.z / reg_dims[4] % reg_dims[5]; + coordinate[8] = blockIdx.z / (reg_dims[4] * reg_dims[5]); + break; + } + +#pragma unroll + for (int dim = N - 1; dim >= 0; --dim) { + output_offset += coordinate[N - 1 - dim] * output_stride[dim]; + } + + out_data[output_offset] = input_data[input_offset]; + } +} + template void StridedCopyKernel(const Context& dev_ctx, const DenseTensor& input, @@ -145,8 +500,6 @@ void StridedCopyKernel(const Context& dev_ctx, } auto numel = input.numel(); - int64_t block = 512; - int64_t grid = (numel + block - 1) / block; if (numel == 1) { #ifdef PADDLE_WITH_HIP @@ -164,1088 +517,649 @@ void StridedCopyKernel(const Context& dev_ctx, return; } - if (input.meta().is_contiguous()) { - switch (input_rank) { - case 1: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - Contiguous2StridedFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - input_rank)); + dim3 grid(1, 1, 1), block(1, 1, 1); + int rank = input_rank; + int tmp = 1; + + for (int i = 0; i < 3 && i < rank; i++) { + tmp *= input_dims[rank - 1 - i]; + } + + if (rank <= 6 && tmp <= 1024 && + (input_dims.size() < 3 || input_dims[rank - 3] <= 64)) { + if (rank >= 1) { + block.x = input_dims[rank - 1]; } - } else if (out->meta().is_contiguous()) { - switch (output_rank) { - case 1: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - Strided2ContiguousFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); + + if (rank >= 2) { + block.y = input_dims[rank - 2]; + } + + if (rank >= 3) { + block.z = input_dims[rank - 3]; + } + + if (input.meta().is_contiguous()) { + switch (rank) { + case 1: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 2: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 3: + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + Contiguous2StridedCaseZeroFunc + <<>>( + input_data, output_data, output_stride); + break; + } + } else if (out->meta().is_contiguous()) { + switch (rank) { + case 1: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 2: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 3: + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 4: + grid.x = input_dims[rank - 4]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + Strided2ContiguousCaseZeroFunc + <<>>( + input_data, input_stride, output_data); + break; + } + } else { + switch (rank) { + case 1: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 2: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 3: + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 4: + grid.x = input_dims[rank - 4]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 5: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + case 6: + grid.x = input_dims[rank - 4]; + grid.y = input_dims[rank - 5]; + grid.z = input_dims[rank - 6]; + StridedCopyCaseZeroFunc<<>>( + input_data, input_stride, output_data, output_stride); + break; + } } } else { - switch (input_rank) { - case 1: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 2: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 3: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 4: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 5: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 6: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 7: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 8: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - case 9: { - switch (output_rank) { - case 1: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 2: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 3: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 4: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 5: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 6: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 7: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 8: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - case 9: - StridedCopyFunc - <<>>(input_data, - input_dims, - input_stride, - output_data, - output_dims, - output_stride, - numel); - break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of output should be less than 9, but received %d.", - output_rank)); - } - } break; - default: - PADDLE_THROW(phi::errors::InvalidArgument( - "The rank of input should be less than 9, but received %d.", - input_rank)); + phi::Array cur_input_dims; + block.x = 512; + + if (input.meta().is_contiguous()) { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Contiguous2StridedCaseOneFunc + <<>>( + input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Contiguous2StridedCaseOneFunc + <<>>(input_data, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } + } else if (out->meta().is_contiguous()) { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Strided2ContiguousCaseOneFunc + <<>>( + input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + Strided2ContiguousCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + cur_input_dims, + input_dims[rank - 1] * + input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } + } else { + switch (rank) { + case 1: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + StridedCopyCaseOneFunc + <<>>(input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1]); + break; + case 2: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2]); + break; + case 3: + grid.x = (numel + block.x - 1) / block.x; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 4: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 5: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 6: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 7: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 8: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + case 9: + grid.x = (input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3] + + block.x - 1) / + block.x; + grid.y = input_dims[rank - 4] * input_dims[rank - 5] * + input_dims[rank - 6]; + grid.z = input_dims[rank - 7] * input_dims[rank - 8] * + input_dims[rank - 9]; + cur_input_dims[0] = input_dims[rank - 1]; + cur_input_dims[1] = input_dims[rank - 2]; + cur_input_dims[2] = input_dims[rank - 4]; + cur_input_dims[3] = input_dims[rank - 5]; + cur_input_dims[4] = input_dims[rank - 7]; + cur_input_dims[5] = input_dims[rank - 8]; + StridedCopyCaseOneFunc<<>>( + input_data, + input_stride, + output_data, + output_stride, + cur_input_dims, + input_dims[rank - 1] * input_dims[rank - 2] * + input_dims[rank - 3]); + break; + default: + PADDLE_THROW(phi::errors::InvalidArgument( + "The rank of input should be less than 9, but received %d.", + rank)); + } } } }