Commit

refactor round #2; even cleaner

gauenk committed Oct 30, 2023
1 parent f92860f commit f00e315
Showing 23 changed files with 1,513 additions and 723 deletions.
14 changes: 14 additions & 0 deletions lib/csrc/nn/anchor_self.cpp
@@ -34,6 +34,9 @@ void anchor_self_refine_forward_cuda(
torch::Tensor dists, torch::Tensor inds,
torch::Tensor flows, int stride0, int H, int W);

void anchor_self_paired_forward_cuda(
torch::Tensor dists, torch::Tensor inds,
torch::Tensor flows, int stride0, int H, int W);

// C++ interface

@@ -68,6 +71,15 @@ void anchor_self_refine_forward(
anchor_self_refine_forward_cuda(dists,inds,flows,stride0,H,W);
}

void anchor_self_paired_forward(
torch::Tensor dists, torch::Tensor inds,
torch::Tensor flows, int stride0, int H, int W){
CHECK_INPUT(dists);
CHECK_INPUT(inds);
CHECK_INPUT(flows);
anchor_self_paired_forward_cuda(dists,inds,flows,stride0,H,W);
}

// python bindings
void init_anchor_self(py::module &m){
m.def("anchor_self", &anchor_self_forward,
@@ -76,5 +88,7 @@ void init_anchor_self(py::module &m){
"anchor_self (CUDA)");
m.def("anchor_self_refine", &anchor_self_refine_forward,
"anchor_self (CUDA)");
m.def("anchor_self_paired", &anchor_self_paired_forward,
"anchor_self (CUDA)");

}
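For orientation: CHECK_INPUT rejects non-CUDA or non-contiguous tensors before dispatch, and the new binding mutates dists and inds in place. A minimal sketch of calling the raw extension module (assuming the build exposes stnls_cuda, as the Python wrapper in this commit does; sizes are illustrative and shapes are inferred from the kernel added below, with Q assumed equal to nH*nW):

import torch
import stnls_cuda  # compiled extension; name taken from the Python wrapper below

B, HD, Q, G, K = 1, 1, 64, 2, 9        # illustrative sizes; Q = nH*nW here
H = W = 32; stride0 = 4; nH = nW = 8   # nH = (H-1)//stride0 + 1
dists = torch.rand(B, HD, Q, G, K, device="cuda")
inds = torch.zeros(B, HD, Q, G, K, 2, dtype=torch.int32, device="cuda")
flows = torch.zeros(B, HD, G, 2, nH, nW, dtype=torch.int32, device="cuda")
stnls_cuda.anchor_self_paired(dists, inds, flows, stride0, H, W)  # in-place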
203 changes: 191 additions & 12 deletions lib/csrc/nn/anchor_self_kernel.cu
@@ -224,14 +224,14 @@ __global__ void anchor_self_time_kernel(
if (qi >= Q){ continue; }

// -- unpack pixel locs --
// get_pixel_loc(loc, qi, tmp, stride0, nW, nHW, H,W);
int tmp = qi;
iloc[0] = qi / nHW;
tmp = (tmp - iloc[0]*nHW);
int nH_index = tmp / nW;
iloc[1] = (nH_index*stride0) % H;
tmp = tmp - nH_index*nW;
iloc[2] = ((tmp % nW) * stride0) % W;
get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W);
// int tmp = qi;
// iloc[0] = qi / nHW;
// tmp = (tmp - iloc[0]*nHW);
// int nH_index = tmp / nW;
// iloc[1] = (nH_index*stride0) % H;
// tmp = tmp - nH_index*nW;
// iloc[2] = ((tmp % nW) * stride0) % W;

// -- select time --
int n_hi = iloc[1]/stride0;
@@ -241,7 +241,6 @@
int t_next = iloc[0] + st_i;
t_next = (t_next > t_max) ? t_max - st_i : t_next;


// -- get anchor index --
loc[0] = t_next - iloc[0];
if (st_i >= st_offset){
@@ -366,6 +365,175 @@
}
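The first hunk above swaps seven lines of inline raster-index arithmetic for the shared get_pixel_loc helper, keeping the old lines as comments. For reference, a Python rendering of that decoding arithmetic (a sketch; the CUDA helper writes into iloc instead of returning a tuple):

def get_pixel_loc_ref(qi, stride0, nW, nHW, H, W):
    # decode raster query index qi -> (t, h, w), mirroring the commented-out lines
    t = qi // nHW
    tmp = qi - t * nHW
    nH_index = tmp // nW
    h = (nH_index * stride0) % H
    tmp = tmp - nH_index * nW
    w = ((tmp % nW) * stride0) % W
    return t, h, w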


/*********************************************************
Anchor Paired Search
*********************************************************/


template <typename scalar_t, typename itype>
__global__ void anchor_self_paired_kernel(
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> dists,
torch::PackedTensorAccessor32<itype,6,torch::RestrictPtrTraits> inds,
torch::PackedTensorAccessor32<itype,6,torch::RestrictPtrTraits> flows,
int stride0, int H, int W, int nHW, int nW, int q_per_thread){

// -- starting qi for thread --
int HD = dists.size(1);
int HD_f = flows.size(1);
int Q = dists.size(2);
int W_t = dists.size(3);
int K = dists.size(4);
int T = flows.size(2);
int bi = blockIdx.y;
int hi = blockIdx.z;
int hi_f = hi % HD_f;
int raster_idx = threadIdx.x + blockDim.x * blockIdx.x;
int qi_thread = raster_idx/W_t;
int gi = (raster_idx - qi_thread*W_t);
qi_thread = q_per_thread*qi_thread;
int self_index = 0;
bool eq_loc;
int iloc[3];
itype loc[3];
itype i_tmp[3];
scalar_t d_tmp;
int qi;
scalar_t delta,dmin_curr;
int min_idx;

// -- for each location --
if (gi >= W_t ){ return; }
for (int qi_ix = 0; qi_ix < q_per_thread; qi_ix++){

// -- current query --
qi = qi_thread + qi_ix;
if (qi >= Q){ continue; }

// -- unpack pixel locs --
get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W);
int n_hi = iloc[1]/stride0;
int n_wi = iloc[2]/stride0;

// -- get anchor index --
auto flows_t = flows[bi][hi_f][gi];
loc[0] = iloc[1] + flows_t[1][n_hi][n_wi];
loc[1] = iloc[2] + flows_t[0][n_hi][n_wi];
loc[0] = bounds(loc[0],H)-iloc[1];
loc[1] = bounds(loc[1],W)-iloc[2];

// -- search for matching index --
min_idx = 0;
dmin_curr = 10000;
for (self_index = 0; self_index < K; self_index++){

delta = 0;
eq_loc = true;
for (int ix=0; ix<2; ix++){
if (is_same_v<itype,int>){
eq_loc = eq_loc && (inds[bi][hi][qi][gi][self_index][ix] == loc[ix]);
}else{
delta += fabs(inds[bi][hi][qi][gi][self_index][ix] - loc[ix]);
}
}
eq_loc = eq_loc && (delta < 1e-8);

if (is_same_v<itype,int>){
if (eq_loc){ min_idx = self_index; break; }
}else{
if (delta < 1e-8){ min_idx = self_index; break; }// break if equal
else if (delta < dmin_curr){ // update min otherwise
min_idx = self_index;
dmin_curr = delta;
}
}

}
assert(min_idx<K);

// -- swap dists --
self_index = min_idx;
d_tmp = dists[bi][hi][qi][gi][0];
dists[bi][hi][qi][gi][0] = dists[bi][hi][qi][gi][self_index];
dists[bi][hi][qi][gi][self_index] = d_tmp;

// -- swap inds --
#pragma unroll
for(int ix=0; ix<2; ix++){
i_tmp[ix] = inds[bi][hi][qi][gi][0][ix];
}
#pragma unroll
for(int ix=0; ix<2; ix++){
inds[bi][hi][qi][gi][0][ix] = loc[ix];
}
#pragma unroll
for(int ix=0; ix<2; ix++){
inds[bi][hi][qi][gi][self_index][ix] = i_tmp[ix];
}

}
}
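Each thread handles one (query, group) pair: it rebuilds the flow-induced anchor offset, scans the K candidates for the matching index (an exact match for int indices; for floats, the first candidate within 1e-8 or else the nearest by L1 distance), and swaps that candidate into slot 0. A per-slice Python reference of the integer branch, as a sketch:

def anchor_paired_ref(dists, inds, loc):
    # dists: K floats; inds: K pairs [dy, dx]; loc: anchor offset [dy, dx]
    min_idx = 0
    for k, ind in enumerate(inds):
        if ind[0] == loc[0] and ind[1] == loc[1]:
            min_idx = k
            break
    dists[0], dists[min_idx] = dists[min_idx], dists[0]  # anchor distance to slot 0
    inds[0], inds[min_idx] = list(loc), list(inds[0])    # slot 0 <- loc; old slot 0 moves out
    return dists, inds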


void anchor_self_paired_forward_cuda(
torch::Tensor dists, torch::Tensor inds,
torch::Tensor flows, int stride0, int H, int W){

// -- unpack --
int B = dists.size(0);
int HD = dists.size(1);
int Q = dists.size(2);
int G = dists.size(3);
int K = dists.size(4);

// -- derivative --
int nH = (H-1)/stride0+1;
int nW = (W-1)/stride0+1;
int nHW = nH*nW;

// -- num 2 run --
int nRun = Q*G;

// -- kernel params --
int q_per_thread = 1;
int _nthreads = 512;
dim3 nthreads(_nthreads);
int _nblocks = (nRun-1)/(_nthreads*q_per_thread)+1;
dim3 nblocks(_nblocks,B,HD);
// fprintf(stdout,"nblocks,nthreads: %d,%d\n",_nblocks,_nthreads);
// fprintf(stdout,"nH,nW,stride0: %d,%d,%d\n",nH,nW,stride0);

// -- launch kernel --
auto itype = get_type(inds);
auto dtype = get_type(dists);
if (itype == torch::kInt32){
AT_DISPATCH_FLOATING_TYPES(dists.type(), "anchor_self_paired_kernel", ([&] {
anchor_self_paired_kernel<scalar_t,int><<<nblocks, nthreads>>>(
dists.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,6,torch::RestrictPtrTraits>(),
flows.packed_accessor32<int,6,torch::RestrictPtrTraits>(),
stride0, H, W, nHW, nW, q_per_thread);
}));
}else if (itype == dtype){
AT_DISPATCH_FLOATING_TYPES(dists.type(), "anchor_self_paired_kernel", ([&] {
anchor_self_paired_kernel<scalar_t,scalar_t><<<nblocks, nthreads>>>(
dists.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
inds.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
flows.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>(),
stride0, H, W, nHW, nW, q_per_thread);
}));

}else{
std::cout << "inds type must be int or match the dists type." << std::endl;
assert(1==0);
}

}
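The launch assigns one thread per (query, group) pair, with q_per_thread queries folded into each thread. A quick sanity check of the grid arithmetic, as a sketch with illustrative sizes:

Q, G = 1024, 2
q_per_thread, nthreads = 1, 512
nRun = Q * G
nblocks = (nRun - 1) // (nthreads * q_per_thread) + 1  # ceil division, as above
assert nblocks * nthreads * q_per_thread >= nRun       # every pair is covered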



/*********************************************************
@@ -397,6 +565,7 @@ __global__ void anchor_self_refine_kernel(
qi_thread = q_per_thread*qi_thread;
int self_index = 0;
bool eq_loc;
int iloc[3];
itype loc[3];
itype i_tmp[3];
scalar_t d_tmp;
@@ -412,10 +581,20 @@
qi = qi_thread + qi_ix;
if (qi >= Q){ continue; }

// -- unpack pixel locs --
// // -- unpack pixel locs --
// loc[0] = round(flows[bi][hi_f][qi][gi][0]);
// loc[1] = flows[bi][hi_f][qi][gi][1];
// loc[2] = flows[bi][hi_f][qi][gi][2];
get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W);
// int n_hi = iloc[1]/stride0;
// int n_wi = iloc[2]/stride0;

// -- get anchor index --
loc[0] = round(flows[bi][hi_f][qi][gi][0]);
loc[1] = flows[bi][hi_f][qi][gi][1];
loc[2] = flows[bi][hi_f][qi][gi][2];
loc[1] = iloc[1] + flows[bi][hi_f][qi][gi][1];
loc[2] = iloc[2] + flows[bi][hi_f][qi][gi][2];
loc[1] = bounds(loc[1],H)-iloc[1];
loc[2] = bounds(loc[2],W)-iloc[2];

// -- search for matching index --
min_idx = 0;
6 changes: 3 additions & 3 deletions lib/csrc/search/paired_search_kernel.cu
@@ -20,7 +20,7 @@ template <typename scalar_t, int DIST_TYPE>
__global__ void paired_search_int_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> frame0,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> frame1,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> flow,
const torch::PackedTensorAccessor32<int,5,torch::RestrictPtrTraits> flow,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> dists,
torch::PackedTensorAccessor32<int,6,torch::RestrictPtrTraits> inds,
int ws, int ps, int stride0, int stride1, int dilation,
@@ -193,7 +193,7 @@ void paired_search_int_forward_cuda(
paired_search_int_forward_kernel<scalar_t,0><<<nblocks, nthreads>>>(
frame0.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
frame1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
flow.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
flow.packed_accessor32<int,5,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,6,torch::RestrictPtrTraits>(),
ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws,
@@ -204,7 +204,7 @@
paired_search_int_forward_kernel<scalar_t,1><<<nblocks, nthreads>>>(
frame0.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
frame1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
flow.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
flow.packed_accessor32<int,5,torch::RestrictPtrTraits>(),
dists.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
inds.packed_accessor32<int,6,torch::RestrictPtrTraits>(),
ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws,
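Both dispatch branches now read flow through an int accessor instead of scalar_t, so the int-indexed forward presumably requires an integer flow tensor from the caller. A hedged conversion sketch (the tensor name and shape here are illustrative, not from the repo):

import torch
flow = torch.randn(1, 1, 2, 8, 8, device="cuda")  # hypothetical float flow field
flow_int = flow.round().to(torch.int32)           # match the kernel's int accessor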
1 change: 1 addition & 0 deletions lib/stnls/nn/__init__.py
@@ -15,6 +15,7 @@
anchor_self = anchor_self_f.run
anchor_self_time = anchor_self_f.run_time
anchor_self_refine = anchor_self_f.run_refine
anchor_self_paired = anchor_self_f.run_paired
non_local_inds = non_local_inds_f.run
accumulate_flow = accumulate_flow_f.run
extract_search_from_accumulated = accumulate_flow_f.extract_search_from_accumulated
30 changes: 26 additions & 4 deletions lib/stnls/nn/anchor_self.py
@@ -49,11 +49,19 @@ def run(dists,inds,stride0,H,W,qstart=0):
def run_refine(dists,inds,flows,stride0,H,W):

    # -- view --
    B,HD,T,nH,nW,Ks,ws,ws = dists.shape
    dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws)
    inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3)
    HD_f = flows.shape[1]
    flows = flows.view(B,HD_f,T*nH*nW,Ks,3)
    if dists.ndim == 8:
        B,HD,T,nH,nW,Ks,ws,ws = dists.shape
        dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws)
        inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3)
        flows = flows.view(B,HD_f,T*nH*nW,Ks,3)
    assert inds.shape[-1] == 3,"Index must be size 3."
    # elif dists.ndim == 6:
    #     B,HD,T,nH,nW,Ks,ws,ws = dists.shape
    #     dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws)
    #     inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3)
    #     flows = flows.view(B,HD_f,T*nH*nW,Ks,3)
    # print("dists.shape,inds.shape,flows.shape: ",dists.shape,inds.shape,flows.shape)

    # -- run --
    stnls_cuda.anchor_self_refine(dists,inds,flows,stride0,H,W)
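run_refine now accepts the unflattened 8-dim layout and flattens it itself; a sketch of the view it applies, with illustrative sizes:

import torch
B, HD, T, nH, nW, Ks, ws = 1, 1, 3, 8, 8, 4, 3
dists = torch.rand(B, HD, T, nH, nW, Ks, ws, ws)
inds = torch.zeros(B, HD, T, nH, nW, Ks, ws, ws, 3, dtype=torch.int32)
dists = dists.view(B, HD, T*nH*nW, Ks, ws*ws)   # (1,1,192,4,9)
inds = inds.view(B, HD, T*nH*nW, Ks, ws*ws, 3)  # (1,1,192,4,9,3)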
@@ -65,7 +73,21 @@ def run_time(dists,inds,flows,wt,stride0,H,W):
    d2or3 = inds.shape[-1]
    dists = dists.view(B,HD,Q,W_t,ws*ws)
    inds = inds.view(B,HD,Q,W_t,ws*ws,d2or3)
    assert d2or3 == 3,"Index must be size 3."

    # -- run --
    stnls_cuda.anchor_self_time(dists,inds,flows,wt,stride0,H,W)

def run_paired(dists,inds,flows,stride0,H,W):

    # -- view --
    B,HD,Q,G,ws,ws = dists.shape
    d2or3 = inds.shape[-1]
    dists = dists.view(B,HD,Q,G,ws*ws)
    inds = inds.view(B,HD,Q,G,ws*ws,d2or3)
    assert d2or3 == 2,"Index must be size 2."

    # -- run --
    stnls_cuda.anchor_self_paired(dists,inds,flows,stride0,H,W)
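End to end, the new entry point is exported as stnls.nn.anchor_self_paired (see lib/stnls/nn/__init__.py above). A hedged usage sketch; the sizes are illustrative, and Q is assumed to equal nH*nW:

import torch
import stnls

B, HD, Q, G, ws = 1, 1, 64, 2, 3
H = W = 32; stride0 = 4; nH = nW = 8
dists = torch.rand(B, HD, Q, G, ws, ws, device="cuda")
inds = torch.zeros(B, HD, Q, G, ws, ws, 2, dtype=torch.int32, device="cuda")
flows = torch.zeros(B, HD, G, 2, nH, nW, dtype=torch.int32, device="cuda")
stnls.nn.anchor_self_paired(dists, inds, flows, stride0, H, W)
# afterwards dists[...,0,0] holds the anchor distance (slot 0 of the flattened ws*ws)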


