Skip to content

Commit

Permalink
add __restrict__
Browse files Browse the repository at this point in the history
  • Loading branch information
kwea123 committed Aug 30, 2022
1 parent 4292895 commit 6b2a669
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions models/csrc/intersection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ __global__ void ray_aabb_intersect_kernel(
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> centers,
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> half_sizes,
const int max_hits,
int* hit_cnt,
int* __restrict__ hit_cnt,
torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> hits_t,
torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> hits_voxel_idx
){
Expand Down Expand Up @@ -127,7 +127,7 @@ __global__ void ray_sphere_intersect_kernel(
const torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> centers,
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> radii,
const int max_hits,
int* hit_cnt,
int* __restrict__ hit_cnt,
torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> hits_t,
torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> hits_sphere_idx
){
Expand Down
14 changes: 7 additions & 7 deletions models/csrc/losses.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

template <typename scalar_t>
__global__ void prefix_sums_kernel(
const scalar_t* ws,
const scalar_t* wts,
const scalar_t* __restrict__ ws,
const scalar_t* __restrict__ wts,
const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
scalar_t* ws_inclusive_scan,
scalar_t* ws_exclusive_scan,
scalar_t* wts_inclusive_scan,
scalar_t* wts_exclusive_scan
scalar_t* __restrict__ ws_inclusive_scan,
scalar_t* __restrict__ ws_exclusive_scan,
scalar_t* __restrict__ wts_inclusive_scan,
scalar_t* __restrict__ wts_exclusive_scan
){
const int n = blockIdx.x * blockDim.x + threadIdx.x;
if (n >= rays_a.size(0)) return;
Expand Down Expand Up @@ -43,7 +43,7 @@ __global__ void prefix_sums_kernel(

template <typename scalar_t>
__global__ void distortion_loss_fw_kernel(
const scalar_t* _loss,
const scalar_t* __restrict__ _loss,
const torch::PackedTensorAccessor64<int64_t, 2, torch::RestrictPtrTraits> rays_a,
torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> loss
){
Expand Down
2 changes: 1 addition & 1 deletion models/csrc/volumerendering.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ __global__ void composite_train_bw_kernel(
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_ddepth,
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> dL_drgb,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> dL_dws,
scalar_t* dL_dws_times_ws,
scalar_t* __restrict__ dL_dws_times_ws,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> sigmas,
const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> rgbs,
const torch::PackedTensorAccessor<scalar_t, 1, torch::RestrictPtrTraits, size_t> deltas,
Expand Down

0 comments on commit 6b2a669

Please sign in to comment.