This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Support for running on arbitrary CUDA device. #537

Merged 8 commits on Mar 26, 2019
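All four CUDA kernels get the same two-line change: include <ATen/cuda/CUDAGuard.h> and construct an at::cuda::CUDAGuard from the device of the first tensor argument, so each kernel launches on that tensor's device rather than on whatever device happens to be current. The guard is an RAII object; roughly, it behaves like this simplified sketch (illustration only — the real class in <ATen/cuda/CUDAGuard.h> also integrates with ATen's Device abstraction and error handling):

#include <cuda_runtime.h>

// Simplified sketch of the RAII idiom the PR relies on; illustration only.
class DeviceGuardSketch {
 public:
  explicit DeviceGuardSketch(int target_device) {
    cudaGetDevice(&prev_device_);  // remember the device active on entry
    cudaSetDevice(target_device);  // make the tensor's device current
  }
  ~DeviceGuardSketch() {
    cudaSetDevice(prev_device_);   // restore the caller's device on scope exit
  }

 private:
  int prev_device_ = 0;
};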
4 changes: 4 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/ROIAlign_cuda.cu
@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
 
 #include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
@@ -263,6 +264,8 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
   AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
 
+  at::cuda::CUDAGuard device_guard(input.device());
+
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
@@ -311,6 +314,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
                                   const int sampling_ratio) {
   AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  at::cuda::CUDAGuard device_guard(grad.device());
 
   auto num_rois = rois.size(0);
   auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
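For context, a hypothetical caller on a non-default device might look like the following. This is not part of the PR, and everything past the first parameter of ROIAlign_forward_cuda is assumed from the rest of ROIAlign_cuda.cu (only the first parameter is visible in the hunk above):

#include <ATen/ATen.h>

// Hypothetical caller, not part of this PR.
void roi_align_on_second_gpu() {
  at::Tensor input = at::rand({1, 256, 50, 50},
                              at::TensorOptions().device(at::Device(at::kCUDA, 1)));
  // Assumed ROI layout: one row per ROI, (batch_idx, x1, y1, x2, y2).
  at::Tensor rois = at::rand({4, 5}, input.options());
  at::Tensor out = ROIAlign_forward_cuda(input, rois,
                                         /*spatial_scale=*/0.25f,
                                         /*pooled_height=*/7,
                                         /*pooled_width=*/7,
                                         /*sampling_ratio=*/2);
  // With the guard in place the kernel launches on cuda:1 to match `input`;
  // before this PR it would have launched on whatever device was current.
}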
4 changes: 4 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/ROIPool_cuda.cu
@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
 
 #include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
@@ -115,6 +116,8 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
   AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
 
+  at::cuda::CUDAGuard device_guard(input.device());
+
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
@@ -167,6 +170,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
   AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
   // TODO add more checks
+  at::cuda::CUDAGuard device_guard(grad.device());
 
   auto num_rois = rois.size(0);
   auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
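The backward functions key the guard off grad.device(). A natural follow-up to the "// TODO add more checks" comment above would be a device-consistency assert before the guard selects a device — a hypothetical sketch, not part of this PR:

// Hypothetical extension of the TODO, not part of this PR: assert that all
// tensor arguments share one device before the guard selects it.
AT_ASSERTM(grad.device() == rois.device(),
           "grad and rois must be on the same CUDA device");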
7 changes: 6 additions & 1 deletion maskrcnn_benchmark/csrc/cuda/SigmoidFocalLoss_cuda.cu
@@ -4,6 +4,7 @@
 // cyfu@cs.unc.edu
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
 
 #include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
@@ -111,6 +112,8 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
   AT_ASSERTM(targets.type().is_cuda(), "targets must be a CUDA tensor");
   AT_ASSERTM(logits.dim() == 2, "logits should be NxClass");
 
+  at::cuda::CUDAGuard device_guard(logits.device());
+
   const int num_samples = logits.size(0);
 
   auto losses = at::empty({num_samples, logits.size(1)}, logits.options());
@@ -156,7 +159,9 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
 
   const int num_samples = logits.size(0);
   AT_ASSERTM(logits.size(1) == num_classes, "logits.size(1) should be num_classes");
-
+
+  at::cuda::CUDAGuard device_guard(logits.device());
+
   auto d_logits = at::zeros({num_samples, num_classes}, logits.options());
   auto d_logits_size = num_samples * logits.size(1);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
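Placement matters in the backward hunk: at::cuda::getCurrentCUDAStream() resolves against the current device, so the guard has to run before the stream is fetched for the subsequent kernel launch to land on the device that owns logits. Annotated, the two relevant lines from the hunk above read:

at::cuda::CUDAGuard device_guard(logits.device());       // makes logits' device current
cudaStream_t stream = at::cuda::getCurrentCUDAStream();  // stream of that device, not the default one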
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/csrc/cuda/nms.cu
@@ -1,6 +1,7 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAGuard.h>
 
 #include <THC/THC.h>
 #include <THC/THCDeviceUtils.cuh>
@@ -70,6 +71,8 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
 at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
   using scalar_t = float;
   AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
+  at::cuda::CUDAGuard device_guard(boxes.device());
+
   auto scores = boxes.select(1, 4);
   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
   auto boxes_sorted = boxes.index_select(0, order_t);
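nms_cuda's full signature is visible above, so a smoke test is easy to sketch. This is a hypothetical test, not part of the PR; the (x1, y1, x2, y2, score) row layout and float32 dtype follow from `using scalar_t = float` and `boxes.select(1, 4)` in the hunk:

#include <ATen/ATen.h>

// Hypothetical smoke test, not part of this PR.
void nms_on_second_gpu() {
  at::Tensor boxes = at::rand({100, 5},
                              at::TensorOptions().device(at::Device(at::kCUDA, 1)));
  at::Tensor keep = nms_cuda(boxes, /*nms_overlap_thresh=*/0.7f);
  // The guard makes the sort, index_select, and the nms kernel all run on
  // cuda:1; previously they used the current (typically default) device.
}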