Skip to content

Commit

Permalink
scatter int passes basic
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Dec 6, 2023
1 parent db05b5b commit 9d2e337
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 140 deletions.
11 changes: 6 additions & 5 deletions lib/csrc/agg/gather_add_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ __global__ void gather_add_int_backward_kernel(
int qi;
int ref[3],ref_p[3],nl[3],nl_p[3];
bool valid;
float weight,pix_n,pix_m;
float weight,grad,pix_m;

// -- batching --
int query_start = q_per_thread*(threadIdx.x + blockDim.x*blockIdx.x);
Expand Down Expand Up @@ -296,16 +296,17 @@ __global__ void gather_add_int_backward_kernel(
// -- skip this temporal offset when either time index falls outside [0,T) --
ref_p[0] = ref[0] + pk;
nl_p[0] = reflect_bounds ? bounds(nl[0]+pk,T) : (nl[0]+pk);
valid = (nl_p[0] >= 0) and (nl_p[0] < T) and (ref_p[0] >= 0) and (ref_p[0] < T);
valid = (nl_p[0] >= 0) and (nl_p[0] < T);
valid = valid and (ref_p[0] >= 0) and (ref_p[0] < T);
if (not valid){ continue; }

// -- num features --
for (int iftr = 0; iftr < F; iftr++){
pix_n = out_vid_grad[ibatch][ihead][ref_p[0]][iftr][ref_p[1]][ref_p[2]];
grad = out_vid_grad[ibatch][ihead][ref_p[0]][iftr][ref_p[1]][ref_p[2]];
pix_m = vid[ibatch][ihead][nl_p[0]][iftr][nl_p[1]][nl_p[2]];
atomicAdd(&in_vid_grad[ibatch][ihead][nl_p[0]][iftr][nl_p[1]][nl_p[2]],
weight*pix_n);
acc_dists_grad += pix_n*pix_m;
weight*grad);
acc_dists_grad += grad*pix_m;
}

} // pt
Expand Down
6 changes: 3 additions & 3 deletions lib/csrc/agg/scatter_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void scatter_add_forward_cuda(
const torch::Tensor in_vid,
const torch::Tensor dists, const torch::Tensor inds,
int ps, int strideIn, int strideOut, int pt,
int dilation, bool reflect_bounds, int patch_offset);
int dilation, bool reflect_bounds, int patch_offset, bool itype_int);

void scatter_add_int_backward_cuda(
torch::Tensor in_vid_grad,
Expand Down Expand Up @@ -50,15 +50,15 @@ void scatter_add_forward(
const torch::Tensor dists,
const torch::Tensor inds,
int ps, int strideIn, int strideOut, int pt,
int dilation, bool reflect_bounds, int patch_offset){
int dilation, bool reflect_bounds, int patch_offset, bool itype_int){
CHECK_INPUT(out_vid);
CHECK_INPUT(counts);
CHECK_INPUT(in_vid);
CHECK_INPUT(dists);
CHECK_INPUT(inds);
scatter_add_forward_cuda(out_vid,counts,in_vid,dists,inds,
ps,strideIn,strideOut,pt,dilation,
reflect_bounds,patch_offset);
reflect_bounds,patch_offset,itype_int);
}


Expand Down
Loading

0 comments on commit 9d2e337

Please sign in to comment.