From ad0f5b338e0c12c916c9da6f324dc5ccdac72a4a Mon Sep 17 00:00:00 2001
From: ZwwWayne
Date: Fri, 4 Sep 2020 12:28:32 +0800
Subject: [PATCH 1/2] Add windows instruction and fix compilation bug

---
 README.md                                | 13 ++++
 mmcv/ops/csrc/pytorch/focal_loss_cuda.cu | 82 ++++++++++++++----------
 2 files changed, 60 insertions(+), 35 deletions(-)

diff --git a/README.md b/README.md
index fe22efc91d..cb9f8bfdcc 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,19 @@ e.g.,
 CC=clang CXX=clang++ CFLAGS='-stdlib=libc++' MMCV_WITH_OPS=1 pip install -e .
 ```
 
+If you are on Windows 10, set the following environment variable before running the install command.
+
+```bash
+set MMCV_WITH_OPS=1
+```
+
+e.g.,
+
+```bash
+set MMCV_WITH_OPS=1
+pip install -e .
+```
+
 Note: If you would like to use `opencv-python-headless` instead of `opencv-python`,
 e.g., in a minimum container environment or servers without GUI,
 you can first install it before installing MMCV to skip the installation of `opencv-python`.
diff --git a/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu b/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
index 508f449ba3..f723512252 100644
--- a/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
@@ -8,17 +8,19 @@ void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
                                                const float alpha) {
   int output_size = output.numel();
   int num_classes = input.size(1);
-  AT_ASSERTM(target.max().item<long>() <= (long)num_classes,
+  AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
              "target label should smaller or equal than num classes");
   at::cuda::CUDAGuard device_guard(input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "sigmoid_focal_loss_forward_cuda_kernel", [&] {
-        sigmoid_focal_loss_forward_cuda_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, input.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        sigmoid_focal_loss_forward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
+        (output_size, input.data_ptr<scalar_t>(),
+         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+         output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -36,11 +38,13 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "sigmoid_focal_loss_backward_cuda_kernel", [&] {
-        sigmoid_focal_loss_backward_cuda_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, input.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        sigmoid_focal_loss_backward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
+        (output_size, input.data_ptr<scalar_t>(),
+         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+         grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -53,17 +57,19 @@ void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
   int output_size = output.numel();
   int num_classes = softmax.size(1);
 
-  AT_ASSERTM(target.max().item<long>() <= (long)num_classes,
+  AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
              "target label should smaller or equal than num classes");
   at::cuda::CUDAGuard device_guard(softmax.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
-        softmax_focal_loss_forward_cuda_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        softmax_focal_loss_forward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
+        (output_size, softmax.data_ptr<scalar_t>(),
+         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+         output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -79,28 +85,34 @@ void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
   int output_size = buff.numel();
   at::cuda::CUDAGuard device_guard(grad_input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(), "softmax_focal_loss_backward_cuda1_kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda1_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
-      });
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.scalar_type(),
+                                      "softmax_focal_loss_backward_cuda1_"
+                                      "kernel",
+                                      [&] {
+        softmax_focal_loss_backward_cuda1_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
+        (output_size, softmax.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+         weight.data_ptr<scalar_t>(), buff.data_ptr<scalar_t>(), gamma, alpha,
+         num_classes);
+      });
 
   AT_CUDA_CHECK(cudaGetLastError());
 
   output_size = grad_input.numel();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(), "softmax_focal_loss_backward_cuda2_kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda2_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
-                grad_input.data_ptr<scalar_t>(), num_classes);
-      });
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.scalar_type(),
+                                      "softmax_focal_loss_backward_cuda2_"
+                                      "kernel",
+                                      [&] {
+        softmax_focal_loss_backward_cuda2_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
+        (output_size, softmax.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+         buff.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
+         num_classes);
+      });
 
   AT_CUDA_CHECK(cudaGetLastError());
 }

From aaa7e72dce9d39fdfe7ce2bd3a999bc383473b28 Mon Sep 17 00:00:00 2001
From: ZwwWayne
Date: Fri, 4 Sep 2020 14:23:08 +0800
Subject: [PATCH 2/2] reformat codebase

---
 mmcv/ops/csrc/pytorch/focal_loss_cuda.cu | 82 +++++++++++------------
 1 file changed, 37 insertions(+), 45 deletions(-)

diff --git a/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu b/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
index f723512252..c7cd215f5d 100644
--- a/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/focal_loss_cuda.cu
@@ -14,13 +14,11 @@ void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "sigmoid_focal_loss_forward_cuda_kernel", [&] {
-        sigmoid_focal_loss_forward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
-        (output_size, input.data_ptr<scalar_t>(),
-         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-         output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        sigmoid_focal_loss_forward_cuda_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -38,13 +36,11 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "sigmoid_focal_loss_backward_cuda_kernel", [&] {
-        sigmoid_focal_loss_backward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
-        (output_size, input.data_ptr<scalar_t>(),
-         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-         grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        sigmoid_focal_loss_backward_cuda_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+                grad_input.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -63,13 +59,11 @@ void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
-        softmax_focal_loss_forward_cuda_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
-        (output_size, softmax.data_ptr<scalar_t>(),
-         target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-         output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+        softmax_focal_loss_forward_cuda_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, softmax.data_ptr<scalar_t>(),
+                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
@@ -85,34 +79,32 @@ void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
   int output_size = buff.numel();
   at::cuda::CUDAGuard device_guard(grad_input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.scalar_type(),
-                                      "softmax_focal_loss_backward_cuda1_"
-                                      "kernel",
-                                      [&] {
-        softmax_focal_loss_backward_cuda1_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
-        (output_size, softmax.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
-         weight.data_ptr<scalar_t>(), buff.data_ptr<scalar_t>(), gamma, alpha,
-         num_classes);
-      });
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_input.scalar_type(),
+      "softmax_focal_loss_backward_cuda1_"
+      "kernel",
+      [&] {
+        softmax_focal_loss_backward_cuda1_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, softmax.data_ptr<scalar_t>(),
+                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
+                buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+      });
 
   AT_CUDA_CHECK(cudaGetLastError());
 
   output_size = grad_input.numel();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.scalar_type(),
-                                      "softmax_focal_loss_backward_cuda2_"
-                                      "kernel",
-                                      [&] {
-        softmax_focal_loss_backward_cuda2_kernel<scalar_t> << <GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream> >>
-        (output_size, softmax.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
-         buff.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
-         num_classes);
-      });
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_input.scalar_type(),
+      "softmax_focal_loss_backward_cuda2_"
+      "kernel",
+      [&] {
+        softmax_focal_loss_backward_cuda2_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, softmax.data_ptr<scalar_t>(),
+                target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
+                grad_input.data_ptr<scalar_t>(), num_classes);
+      });
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
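
For context on the compilation fix in these patches: the change from `long` to `int64_t` in the cast (and in the tensor accessors) matters because Windows/MSVC is an LLP64 platform where `long` is only 32 bits, while `int64_t` is 64 bits on every platform and matches the element type of the int64 `target` label tensor read by the kernels. A minimal standalone sketch (plain C++, not part of the patches themselves) showing the width difference:

```cpp
// Standalone illustration (assumption: any C++11 compiler): why int64_t is
// preferred over long for values that must stay 64-bit on all platforms.
#include <cstdint>
#include <cstdio>

int main() {
  // LP64 (Linux/macOS): long is 8 bytes.  LLP64 (Windows/MSVC): long is 4 bytes.
  std::printf("sizeof(long)    = %zu\n", sizeof(long));
  // int64_t is 8 bytes everywhere, so (int64_t)num_classes and
  // target.max().item<int64_t>() compare at full 64-bit width on Windows too.
  std::printf("sizeof(int64_t) = %zu\n", sizeof(int64_t));
  return 0;
}
```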