Skip to content

Commit

Permalink
[Feature] : Add Deformable Conv2d TensorRT Plugin (#858)
Browse files Browse the repository at this point in the history
* add dcn tensorrt plugin

* prepare for fp16 support

* fix for lint

* limit column buffer

* add docstring to memcpyPermute
  • Loading branch information
grimoire authored Mar 11, 2021
1 parent 57f3a61 commit 9ba1f76
Show file tree
Hide file tree
Showing 9 changed files with 763 additions and 3 deletions.
9 changes: 7 additions & 2 deletions mmcv/ops/csrc/deform_conv_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,16 @@
#ifndef DEFORM_CONV_CUDA_KERNEL_CUH
#define DEFORM_CONV_CUDA_KERNEL_CUH

#include <float.h>
#ifdef MMCV_WITH_TRT
#include "common_cuda_helper.hpp"
#else // MMCV_WITH_TRT
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#else // MMCV_USE_PARROTS
#include "pytorch_cuda_helper.hpp"
#endif
#endif // MMCV_USE_PARROTS
#endif // MMCV_WITH_TRT

template <typename T>
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
Expand Down
66 changes: 66 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_cuda_helper.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

using mmcv::TensorDesc;

/**
 * Element-wise permuted copy kernel.
 *
 * For every linear index into `dst` (0 <= index < n) the kernel decomposes
 * the index into per-dimension coordinates using the destination strides,
 * maps each destination dimension `i` to source dimension `permute[i]`, and
 * gathers the matching element from `src`.
 *
 * @param dst            output buffer, written at linear indices [0, n)
 * @param src            input buffer, read at the permuted offsets
 * @param n              number of elements to copy
 * @param ts_src_stride  `dim` holds the number of dimensions; `stride` holds
 *                       the per-dimension strides of `src`
 * @param ts_dst_stride  `stride` holds the per-dimension strides of `dst`
 * @param ts_permute     `shape` holds, for each destination dimension, the
 *                       source dimension it is taken from
 */
template <class scalar_t>
__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n,
                                    TensorDesc ts_src_stride,
                                    TensorDesc ts_dst_stride,
                                    TensorDesc ts_permute) {
  const int num_dims = ts_src_stride.dim;
  const int *in_strides = ts_src_stride.stride;
  const int *out_strides = ts_dst_stride.stride;
  const int *dim_map = ts_permute.shape;

  CUDA_1D_KERNEL_LOOP(index, n) {
    // Part of the destination index that has not been decomposed yet.
    size_t remaining = index;
    // Accumulated linear offset into the source buffer.
    size_t src_offset = 0;
    for (int d = 0; d < num_dims; ++d) {
      // Coordinate along destination dimension d.
      const int coord = remaining / out_strides[d];
      remaining %= out_strides[d];
      // Destination dimension d corresponds to source dimension dim_map[d].
      src_offset += coord * in_strides[dim_map[d]];
    }
    dst[index] = src[src_offset];
  }
}

/**
 * Copy `src` into `dst` with its dimensions permuted.
 *
 * Destination dimension `i` takes its extent from source dimension
 * `permute[i]` (dst_size[i] == src_size[permute[i]]). Strides for both
 * buffers are derived here on the host assuming densely packed storage with
 * the last dimension contiguous, then passed by value to
 * copy_permute_kernel, which is launched on `stream`. The call does not
 * synchronize the stream.
 *
 * @param dst       output buffer (dereferenced inside the kernel)
 * @param src       input buffer (dereferenced inside the kernel)
 * @param src_size  host array of `src_dim` extents describing `src`
 * @param permute   host array of `src_dim` source-dimension indices
 * @param src_dim   number of dimensions; values <= 0 make the call a no-op
 * @param stream    CUDA stream the kernel is enqueued on
 */
template <class scalar_t>
void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size,
                   int *permute, int src_dim, cudaStream_t stream) {
  // Without at least one dimension there is nothing to copy, and the stride
  // initialisation below would write to index src_dim - 1 < 0.
  if (src_dim <= 0) {
    return;
  }
  // NOTE(review): src_dim must also fit the fixed-size arrays inside
  // TensorDesc; the capacity is not visible from this file - confirm at the
  // call sites.

  TensorDesc ts_permute;
  ts_permute.dim = src_dim;
  memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));

  TensorDesc ts_src_stride;
  TensorDesc ts_dst_stride;
  ts_src_stride.dim = src_dim;
  ts_dst_stride.dim = src_dim;
  int *src_stride = &(ts_src_stride.stride[0]);
  int *dst_stride = &(ts_dst_stride.stride[0]);
  int *dst_size = &(ts_dst_stride.shape[0]);

  // Innermost dimension is contiguous in both layouts.
  src_stride[src_dim - 1] = 1;
  dst_stride[src_dim - 1] = 1;

  // Walk from the innermost dimension outwards so that entry i + 1 is always
  // available when entry i is computed.
  size_t copy_size = 1;
  for (int i = src_dim - 1; i >= 0; --i) {
    dst_size[i] = src_size[permute[i]];
    copy_size *= dst_size[i];
    if (i < src_dim - 1) {
      src_stride[i] = src_stride[i + 1] * src_size[i + 1];
      dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
    }
  }

  // Empty tensor: skip the launch instead of requesting a grid for zero
  // elements.
  if (copy_size == 0) {
    return;
  }

  // The kernel takes its element count as int; make the narrowing explicit.
  const int num_elements = static_cast<int>(copy_size);

  copy_permute_kernel<scalar_t>
      <<<GET_BLOCKS(num_elements), THREADS_PER_BLOCK, 0, stream>>>(
          dst, src, num_elements, ts_src_stride, ts_dst_stride, ts_permute);
}

// Explicit instantiation of memcpyPermute for float so the template
// definition in this translation unit is emitted for that element type.
template void memcpyPermute<float>(float *dst, const float *src, int *src_size,
int *permute, int src_dim,
cudaStream_t stream);
Loading

0 comments on commit 9ba1f76

Please sign in to comment.