Skip to content

Commit

Permalink
NMS code cleanup (#2907)
Browse files Browse the repository at this point in the history
* Clean up and refactor ROIAlign implementation:
- Remove primitive const declaration from method names.
- Remove unnecessary headers.
- Aligning method names between cpu and cuda.

* Adding back include for cpu.

* Restoring method names of private methods to avoid conflicts.

* Restore include headers.
  • Loading branch information
datumbox authored Oct 30, 2020
1 parent c9d9e67 commit 455cd57
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 9 deletions.
4 changes: 2 additions & 2 deletions torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
TORCH_CHECK(
Expand Down Expand Up @@ -72,7 +72,7 @@ at::Tensor nms_cpu_kernel(
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cpu(
VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
double iou_threshold);

VISION_API at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/cuda/nms_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ __device__ inline bool devIoU(T const* const a, T const* const b, const float th

template <typename T>
__global__ void nms_kernel(
const int n_boxes,
const float iou_threshold,
int n_boxes,
double iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
Expand Down Expand Up @@ -70,7 +70,7 @@ __global__ void nms_kernel(

at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor");
TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor");

Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ VISION_API at::Tensor PSROIAlign_backward_cuda(
VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold);
double iou_threshold);

VISION_API at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/csrc/nms.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
Expand All @@ -14,7 +15,7 @@
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::nms", "")
.typed<decltype(nms)>();
Expand All @@ -25,7 +26,7 @@ at::Tensor nms(
at::Tensor nms_autocast(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
double iou_threshold) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return nms(
at::autocast::cached_cast(at::kFloat, dets),
Expand Down

0 comments on commit 455cd57

Please sign in to comment.