From 455cd57c9404429448280cd4a0d7f600a51afc54 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 30 Oct 2020 10:53:51 +0000 Subject: [PATCH] NMS code cleanup (#2907) * Clean up and refactor ROIAlign implementation: - Remove primitive const declaration from method names. - Remove unnecessary headers. - Aligning method names between cpu and cuda. * Adding back include for cpu. * Restoring method names of private methods to avoid conflicts. * Restore include headers. --- torchvision/csrc/cpu/nms_cpu.cpp | 4 ++-- torchvision/csrc/cpu/vision_cpu.h | 2 +- torchvision/csrc/cuda/nms_cuda.cu | 6 +++--- torchvision/csrc/cuda/vision_cuda.h | 2 +- torchvision/csrc/nms.h | 5 +++-- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torchvision/csrc/cpu/nms_cpu.cpp b/torchvision/csrc/cpu/nms_cpu.cpp index cbb5e056c69..00a4c61db7a 100644 --- a/torchvision/csrc/cpu/nms_cpu.cpp +++ b/torchvision/csrc/cpu/nms_cpu.cpp @@ -4,7 +4,7 @@ template at::Tensor nms_cpu_kernel( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold) { + double iou_threshold) { TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); TORCH_CHECK( @@ -72,7 +72,7 @@ at::Tensor nms_cpu_kernel( at::Tensor nms_cpu( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold) { + double iou_threshold) { TORCH_CHECK( dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); TORCH_CHECK( diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 69b1bbf555d..6a34945b11e 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cpu( VISION_API at::Tensor nms_cpu( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold); + double iou_threshold); VISION_API at::Tensor DeformConv2d_forward_cpu( const at::Tensor& input, diff --git a/torchvision/csrc/cuda/nms_cuda.cu b/torchvision/csrc/cuda/nms_cuda.cu index 20e090c2041..548dc2f69cb 100644 --- a/torchvision/csrc/cuda/nms_cuda.cu +++ b/torchvision/csrc/cuda/nms_cuda.cu @@ -22,8 +22,8 @@ __device__ inline bool devIoU(T const* const a, T const* const b, const float th template __global__ void nms_kernel( - const int n_boxes, - const float iou_threshold, + int n_boxes, + double iou_threshold, const T* dev_boxes, unsigned long long* dev_mask) { const int row_start = blockIdx.y; @@ -70,7 +70,7 @@ __global__ void nms_kernel( at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold) { + double iou_threshold) { TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 2481cfc63c2..0652350a01b 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cuda( VISION_API at::Tensor nms_cuda( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold); + double iou_threshold); VISION_API at::Tensor DeformConv2d_forward_cuda( const at::Tensor& input, diff --git a/torchvision/csrc/nms.h b/torchvision/csrc/nms.h index ac6fb93ea5d..aed675e5d26 100644 --- a/torchvision/csrc/nms.h +++ b/torchvision/csrc/nms.h @@ -1,4 +1,5 @@ #pragma once + #include "cpu/vision_cpu.h" #ifdef WITH_CUDA @@ -14,7 +15,7 @@ at::Tensor nms( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold) { + double iou_threshold) { static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("torchvision::nms", "") .typed(); @@ -25,7 +26,7 @@ at::Tensor nms( at::Tensor nms_autocast( const at::Tensor& dets, const at::Tensor& scores, - const double iou_threshold) { + double iou_threshold) { c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); return nms( at::autocast::cached_cast(at::kFloat, dets),