Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update for CUDA 10.2 #42

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions pointnet2/src/ball_query.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "ball_query_gpu.h"

extern THCState *state;

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
Expand All @@ -19,7 +18,7 @@ int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,
const float *xyz = xyz_tensor.data<float>();
int *idx = idx_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
return 1;
}
}
12 changes: 5 additions & 7 deletions pointnet2/src/group_points.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <THC/THC.h>
#include "group_points_gpu.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>

extern THCState *state;


int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample,
Expand All @@ -15,8 +15,7 @@ int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample
const int *idx = idx_tensor.data<int>();
const float *grad_out = grad_out_tensor.data<float>();

cudaStream_t stream = THCState_getCurrentStream(state);

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
return 1;
}
Expand All @@ -29,8 +28,7 @@ int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample,
const int *idx = idx_tensor.data<int>();
float *out = out_tensor.data<float>();

cudaStream_t stream = THCState_getCurrentStream(state);

cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
return 1;
}
}
13 changes: 6 additions & 7 deletions pointnet2/src/interpolate.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
#include <torch/serialize/tensor.h>
#include <vector>
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"

extern THCState *state;
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
Expand All @@ -18,7 +17,7 @@ void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
float *dist2 = dist2_tensor.data<float>();
int *idx = idx_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
}

Expand All @@ -34,7 +33,7 @@ void three_interpolate_wrapper_fast(int b, int c, int m, int n,
float *out = out_tensor.data<float>();
const int *idx = idx_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
}

Expand All @@ -49,6 +48,6 @@ void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
float *grad_points = grad_points_tensor.data<float>();
const int *idx = idx_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
}
}
11 changes: 5 additions & 6 deletions pointnet2/src/sampling.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <vector>
#include <THC/THC.h>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include "sampling_gpu.h"

extern THCState *state;


int gather_points_wrapper_fast(int b, int c, int n, int npoints,
Expand All @@ -14,7 +13,7 @@ int gather_points_wrapper_fast(int b, int c, int n, int npoints,
const int *idx = idx_tensor.data<int>();
float *out = out_tensor.data<float>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
return 1;
}
Expand All @@ -27,7 +26,7 @@ int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints,
const int *idx = idx_tensor.data<int>();
float *grad_points = grad_points_tensor.data<float>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
return 1;
}
Expand All @@ -40,7 +39,7 @@ int furthest_point_sampling_wrapper(int b, int n, int m,
float *temp = temp_tensor.data<float>();
int *idx = idx_tensor.data<int>();

cudaStream_t stream = THCState_getCurrentStream(state);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
return 1;
}