create graph opts
gauenk committed Dec 8, 2023
1 parent 2e4f01f commit b8b5ca3
Showing 25 changed files with 2,212 additions and 792 deletions.
81 changes: 16 additions & 65 deletions lib/csrc/agg/scatter.cpp
@@ -14,21 +14,16 @@ void scatter_int_forward_cuda(
     int ps, int pt, int dilation, int stride0,
     bool reflect_bounds, int patch_offset);
 
-void scatter_labels_cuda(
-    const torch::Tensor flows, const torch::Tensor flows_k,
-    torch::Tensor labels, torch::Tensor names,
-    int ws, int wt, int stride0, float stride1, bool full_ws);
-
-void scatter_tensor_forward_cuda(torch::Tensor out_tensor,
-     const torch::Tensor in_tensor,
-     const torch::Tensor labels,
-     const torch::Tensor flows_k,
-     int stride0, int stride1, int H, int W);
-
-void scatter_tensor_backward_cuda(torch::Tensor in_tensor_grad,
-     const torch::Tensor out_tensor_grad,
-     const torch::Tensor labels,
-     const torch::Tensor flows_k, int stride0);
+// void scatter_tensor_forward_cuda(torch::Tensor out_tensor,
+//      const torch::Tensor in_tensor,
+//      const torch::Tensor labels,
+//      const torch::Tensor flows_k,
+//      int stride0, int stride1, int H, int W);
+
+// void scatter_tensor_backward_cuda(torch::Tensor in_tensor_grad,
+//      const torch::Tensor out_tensor_grad,
+//      const torch::Tensor labels,
+//      const torch::Tensor flows_k, int stride0);
 
 // void scatter_bilin2d_forward_cuda(
 //     const torch::Tensor vid, const torch::Tensor weights,
@@ -67,50 +62,6 @@ void scatter_tensor_backward_cuda(torch::Tensor in_tensor_grad,
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-/*********************************
-
-       Using Raster Order
-
-*********************************/
-
-void scatter_tensor_forward(
-    torch::Tensor out_tensor,
-    const torch::Tensor in_tensor,
-    const torch::Tensor labels,
-    const torch::Tensor flows_k,
-    int stride0, int stride1, int H, int W){
-  CHECK_INPUT(out_tensor);
-  CHECK_INPUT(in_tensor);
-  CHECK_INPUT(labels);
-  CHECK_INPUT(flows_k);
-  scatter_tensor_forward_cuda(out_tensor,in_tensor,labels,flows_k,
-                              stride0,stride1,H,W);
-}
-
-void scatter_tensor_backward(
-    torch::Tensor out_tensor_grad,
-    const torch::Tensor in_tensor_grad,
-    const torch::Tensor labels,
-    const torch::Tensor flows_k, int stride0){
-  CHECK_INPUT(in_tensor_grad);
-  CHECK_INPUT(out_tensor_grad);
-  CHECK_INPUT(labels);
-  CHECK_INPUT(flows_k);
-  scatter_tensor_backward_cuda(in_tensor_grad,out_tensor_grad,labels,flows_k,stride0);
-}
-
-void scatter_labels(
-    const torch::Tensor flows, const torch::Tensor flows_k,
-    torch::Tensor labels, torch::Tensor names,
-    int ws, int wt, int stride0, float stride1, bool full_ws){
-  CHECK_INPUT(flows);
-  CHECK_INPUT(flows_k);
-  CHECK_INPUT(labels);
-  CHECK_INPUT(names);
-  scatter_labels_cuda(flows,flows_k,labels,names,
-                      ws,wt,stride0,stride1,full_ws);
-}
-
 void scatter_int_forward(
     const torch::Tensor vid, const torch::Tensor weights,
     const torch::Tensor inds, const torch::Tensor labels,
@@ -200,12 +151,12 @@ void scatter_int_forward(
 
 // python bindings
 void init_scatter(py::module &m){
-  m.def("scatter_labels", &scatter_labels,
-        "Scatter Labels");
-  m.def("scatter_tensor_forward", &scatter_tensor_forward,
-        "Scatter Tensor");
-  m.def("scatter_tensor_backward", &scatter_tensor_backward,
-        "Scatter Tensor");
+  // m.def("scatter_labels", &scatter_labels,
+  //       "Scatter Labels");
+  // m.def("scatter_tensor_forward", &scatter_tensor_forward,
+  //       "Scatter Tensor");
+  // m.def("scatter_tensor_backward", &scatter_tensor_backward,
+  //       "Scatter Tensor");
   m.def("scatter_int_forward", &scatter_int_forward,
        "Scatter Forward with Int Indexing");
   // m.def("scatter_int_backward",&scatter_int_backward,
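The surviving wrappers all funnel through the CHECK_INPUT guard defined above, which rejects CPU-resident or non-contiguous tensors. As a minimal caller-side sketch (the helper name is hypothetical, not part of this commit), inputs are typically normalized before crossing into the extension:

#include <torch/extension.h>

// Hypothetical helper: move a tensor onto the GPU and make its memory
// contiguous so that CHECK_INPUT (CHECK_CUDA + CHECK_CONTIGUOUS) passes.
torch::Tensor prep_for_extension(torch::Tensor x) {
  return x.to(torch::kCUDA).contiguous();
}
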
23 changes: 13 additions & 10 deletions lib/csrc/agg/scatter_add_kernel.cu
@@ -96,21 +96,24 @@ __global__ void scatter_add_forward_kernel(
       ref_p[2] = ref[2]+dilation*(pj + patch_offset);
       check_bounds(valid, ref_p, T, inH, inW);
       if (not valid){ continue; }
 
-      // -- increment legal refs --
-      if ((ref[0]==0) and (ibatch==0) and (ihead==0) and (ki==0)){
-        atomicAdd(&counts[ref_p[1]][ref_p[2]],1);
-      }
-
 
       // -- non-local pixel index --
       nl_p[0] = nl[0];
       nl_p[1] = nl[1]+dilation*(pi + patch_offset);
-      nl_p[1] = reflect_bounds ? bounds(nl_p[1],inH) : nl_p[1];
+      nl_p[1] = reflect_bounds ? bounds(nl_p[1],outH) : nl_p[1];
       nl_p[2] = nl[2]+dilation*(pj + patch_offset);
-      nl_p[2] = reflect_bounds ? bounds(nl_p[2],inW) : nl_p[2];
-      check_bounds(valid, nl_p, T, inH, inW);
+      nl_p[2] = reflect_bounds ? bounds(nl_p[2],outW) : nl_p[2];
+      check_bounds(valid, nl_p, T, outH, outW);
       if (not valid){ continue; }
 
+      // -- increment legal refs --
+      // if ((ref[0]==0) and (ibatch==0) and (ihead==0) and (ki==0)){
+      //   atomicAdd(&counts[nl_p[1]][nl_p[2]],1);
+      // }
+      if ((ref[0]==0) and (ibatch==0) and (ihead==0)){
+        atomicAdd(&counts[nl_p[1]][nl_p[2]],1);
+      }
 
       // -- iterate over loop --
      for(int pk = 0; pk < pt; pk++){
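The substantive change in this hunk is that the non-local coordinates are now reflected and bounds-checked against the output size (outH, outW) rather than the input size, and the per-pixel counts are accumulated at each non-local location over all K neighbors instead of once per reference pixel. The bounds() helper is defined elsewhere in the library; a plausible sketch of its reflection rule, assuming standard mirror-padding semantics (this is an assumption, not taken from this commit):

// Sketch under assumed mirror-padding semantics: fold an out-of-range
// coordinate back into [0, L-1] by reflecting about the edges,
// e.g. -1 -> 1 and L -> L-2.
__device__ __forceinline__ int bounds(int p, int L) {
  if (p < 0)  { return -p; }
  if (p >= L) { return 2 * (L - 1) - p; }
  return p;
}
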
57 changes: 57 additions & 0 deletions lib/csrc/graph_opts/gather_tensor.cpp
@@ -0,0 +1,57 @@

#include <torch/extension.h>
#include <vector>

// gathering

void gather_tensor_forward_cuda(torch::Tensor out_tensor,
const torch::Tensor in_tensor,
const torch::Tensor labels,
const torch::Tensor flows_k,
int stride0, int stride1, int H, int W);

void gather_tensor_backward_cuda(torch::Tensor in_tensor_grad,
const torch::Tensor out_tensor_grad,
const torch::Tensor labels,
const torch::Tensor flows_k, int stride0);

// C++ interface

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Forward/Backward

void gather_tensor_forward(
torch::Tensor out_tensor,
const torch::Tensor in_tensor,
const torch::Tensor labels,
const torch::Tensor flows_k,
int stride0, int stride1, int H, int W){
CHECK_INPUT(out_tensor);
CHECK_INPUT(in_tensor);
CHECK_INPUT(labels);
CHECK_INPUT(flows_k);
gather_tensor_forward_cuda(out_tensor,in_tensor,labels,flows_k,
stride0,stride1,H,W);
}

void gather_tensor_backward(
    const torch::Tensor out_tensor_grad,
    torch::Tensor in_tensor_grad,
    const torch::Tensor labels,
    const torch::Tensor flows_k, int stride0){
CHECK_INPUT(in_tensor_grad);
CHECK_INPUT(out_tensor_grad);
CHECK_INPUT(labels);
CHECK_INPUT(flows_k);
gather_tensor_backward_cuda(in_tensor_grad,out_tensor_grad,labels,flows_k,stride0);
}

// python bindings
void init_gather_tensor(py::module &m){
m.def("gather_tensor_forward", &gather_tensor_forward,"Gather Labels");
m.def("gather_tensor_backward", &gather_tensor_backward,"Gather Labels");
}
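
init_gather_tensor only registers the two bindings on a module handle; the pybind11 entry point lives elsewhere in the library. A hypothetical sketch of the wiring (the entry-point layout is an assumption, not shown in this commit):

#include <torch/extension.h>

// Forward declarations of the per-file registration hooks.
void init_scatter(py::module &m);        // lib/csrc/agg/scatter.cpp
void init_gather_tensor(py::module &m);  // lib/csrc/graph_opts/gather_tensor.cpp

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  init_scatter(m);        // scatter_int_forward, ...
  init_gather_tensor(m);  // gather_tensor_forward / gather_tensor_backward
}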

(Diffs for the remaining 22 of the 25 changed files did not load and are not shown.)