From f92860fd0e6d995b149f0a26f9c4914532bd794e Mon Sep 17 00:00:00 2001 From: gauenk Date: Sun, 29 Oct 2023 13:14:21 -0400 Subject: [PATCH] added refine skel --- lib/csrc/search/paired_refine.cpp | 147 +++++ lib/csrc/search/paired_refine_kernel.cu | 772 ++++++++++++++++++++++++ lib/stnls/search/paired_refine.py | 327 ++++++++++ 3 files changed, 1246 insertions(+) create mode 100644 lib/csrc/search/paired_refine.cpp create mode 100644 lib/csrc/search/paired_refine_kernel.cu create mode 100644 lib/stnls/search/paired_refine.py diff --git a/lib/csrc/search/paired_refine.cpp b/lib/csrc/search/paired_refine.cpp new file mode 100644 index 0000000..cdfdfd1 --- /dev/null +++ b/lib/csrc/search/paired_refine.cpp @@ -0,0 +1,147 @@ +#include + +#include + +// CUDA forward declarations + +void paired_refine_int_forward_cuda( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, torch::Tensor dists, torch::Tensor inds, + int ps, int k, int stride0, int stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type); + +void paired_refine_bilin2d_forward_cuda( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + torch::Tensor dists, torch::Tensor inds, torch::Tensor kselect, + int ps, int k, int stride0, float stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type); + +void paired_refine_int_backward_cuda( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor grad_dists, const torch::Tensor inds, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type); + +void paired_refine_bilin2d_backward_cuda( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + torch::Tensor grad_flow, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + const torch::Tensor grad_dists, const torch::Tensor grad_inds, + const torch::Tensor inds, const torch::Tensor kselect, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type); + + +// C++ interface + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + +void paired_refine_int_forward( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, torch::Tensor dists, torch::Tensor inds, + int ps, int k, int stride0, int stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type){ + CHECK_INPUT(frame0); + CHECK_INPUT(frame1); + CHECK_INPUT(flow); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + paired_refine_int_forward_cuda(frame0, frame1, flow, dists, inds, + ps, k, stride0, stride1, dilation, + reflect_bounds, full_ws, patch_offset, dist_type); + +} + +void paired_refine_bilin2d_forward( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + torch::Tensor dists, torch::Tensor inds, torch::Tensor kselect, + int ps, int k, int stride0, float stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type){ + CHECK_INPUT(frame0); + CHECK_INPUT(frame1); + CHECK_INPUT(flow); + CHECK_INPUT(kselect); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + paired_refine_bilin2d_forward_cuda(frame0, frame1, flow, dists, inds, kselect, + ps, k, stride0, stride1, dilation, + reflect_bounds, 
full_ws, patch_offset, dist_type); +} + +void paired_refine_int_backward( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor grad_dists, const torch::Tensor inds, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type) { + + // -- validate -- + CHECK_INPUT(grad_frame0); + CHECK_INPUT(grad_frame1); + CHECK_INPUT(frame0); + CHECK_INPUT(frame1); + CHECK_INPUT(grad_dists); + CHECK_INPUT(inds); + + // -- refine -- + paired_refine_int_backward_cuda( + grad_frame0, grad_frame1, + frame0, frame1, + grad_dists, inds, + stride0, ps, dilation, reflect_bounds, + patch_offset, dist_type); + +} + + +void paired_refine_bilin2d_backward( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + torch::Tensor grad_flow, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + const torch::Tensor grad_dists, const torch::Tensor grad_inds, + const torch::Tensor inds, const torch::Tensor kselect, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type) { + + // -- validate -- + CHECK_INPUT(grad_frame0); + CHECK_INPUT(grad_frame1); + CHECK_INPUT(frame0); + CHECK_INPUT(frame1); + CHECK_INPUT(flow); + CHECK_INPUT(grad_dists); + CHECK_INPUT(grad_inds); + CHECK_INPUT(inds); + CHECK_INPUT(kselect); + + paired_refine_bilin2d_backward_cuda( + grad_frame0, grad_frame1, grad_flow, + frame0, frame1, flow, + grad_dists, grad_inds, inds, kselect, + stride0, ps, dilation, reflect_bounds, + patch_offset, dist_type); + +} + + +// python bindings +void init_paired_refine(py::module &m){ + m.def("paired_refine_int_forward", &paired_refine_int_forward, + "Refine Forward with Heads (CUDA)"); + m.def("paired_refine_bilin2d_forward", &paired_refine_bilin2d_forward, + "Refine Forward with Heads (CUDA)"); + m.def("paired_refine_int_backward", &paired_refine_int_backward, + "Refine Backward with Heads (CUDA)"); + m.def("paired_refine_bilin2d_backward", &paired_refine_bilin2d_backward, + "Refine Backward with Heads (CUDA)"); + +} + diff --git a/lib/csrc/search/paired_refine_kernel.cu b/lib/csrc/search/paired_refine_kernel.cu new file mode 100644 index 0000000..55207e7 --- /dev/null +++ b/lib/csrc/search/paired_refine_kernel.cu @@ -0,0 +1,772 @@ + +// #include +#include +#include +#include +#include +#include +#include "paired_details.cu" + +using namespace at; + + +/**************************** + + Forward Pass + +****************************/ + +template +__global__ void paired_refine_int_forward_kernel( + const torch::PackedTensorAccessor32 frame0, + const torch::PackedTensorAccessor32 frame1, + const torch::PackedTensorAccessor32 flow, + torch::PackedTensorAccessor32 dists, + torch::PackedTensorAccessor32 inds, + int ws, int ps, int stride0, int stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, + int q_per_thread, int ws_per_thread){ + + // -- unpack shape -- + int B = frame0.size(0); + int HD_frame = frame0.size(1); + int HD_flow = flow.size(1); + int HD_search = inds.size(1); + int C = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + int Q = dists.size(2); + int HD = max(HD_frame,HD_flow); + + // -- invalid constant -- + scalar_t invalid = (scalar_t)__int_as_float(0x7f800000); + if(DIST_TYPE == 0){ // prod + invalid = -invalid; + } + + + // -- search region offsets -- + // int psHalf = (ps)/2; + int wsHalf = (ws-1)/2; + // int wsHalf_w = (ws_w)/2; + // int adj = use_adj ? 
psHalf : 0; + int wsOff_h,wsOff_w; + // int wsMax_h = stride1*(ws_h-1-wsHalf_h); + // int wsMax_w = stride1*(ws_w-1-wsHalf_w); + + // -- cuda index -- + int ihead = blockIdx.z/B; + int ibatch = (blockIdx.z-ihead*B) % B; + int si = blockIdx.y; + int ihead_fr = ihead % HD_frame; + int ihead_fl = ihead % HD_flow; + int ihead_sr = ihead % HD_search; + int q_start = blockIdx.x*q_per_thread; + int qi,ws_i,ws_j; + + // decls + int ref_patch[2]; + int prop_patch[2]; + int frame_anchor[2]; + int ref_pix[2]; + int prop_pix[2]; + bool valid; + bool valid_ref_patch,valid_prop_patch; + bool valid_ref[3]; + bool valid_prop[3]; + + // -- indexing -- + scalar_t dist; + + for (int q_index = 0; q_index < q_per_thread; q_index++){ + + + //--------------------------- + // Anchor Pixel + //--------------------------- + + // -- block start -- + qi = q_start + q_index; + if (qi >= Q){ continue; } + + // -- pixel location from query index -- + get_pixel_loc_2d(ref_patch,qi,stride0,H,W); + int nh = ref_patch[0]/stride0; + int nw = ref_patch[1]/stride0; + check_bounds_2d(valid_ref_patch,ref_patch,H,W); + + // -- assign to reference -- + frame_anchor[0] = ref_patch[0] + flow[ibatch][ihead_fl][si][nh][nw][0]; + frame_anchor[1] = ref_patch[1] + flow[ibatch][ihead_fl][si][nh][nw][1]; + frame_anchor[0] = bounds(frame_anchor[0],H); + frame_anchor[1] = bounds(frame_anchor[1],W); + + // -- search region offsets -- + set_search_offsets(wsOff_h, wsOff_w, + frame_anchor[0], frame_anchor[1], stride1, + wsHalf, ws, H, W, full_ws); + + // --------------------------------------- + // spatial searching + // --------------------------------------- + + // -- search across space -- + for (int _xi = 0; _xi < ws_per_thread; _xi++){ + ws_i = threadIdx.x + blockDim.x*_xi; + if (ws_i >= ws){ continue; } + for (int _yi = 0; _yi < ws_per_thread; _yi++){ + ws_j = threadIdx.y + blockDim.y*_yi; + if (ws_j >= ws){ continue; } + + // -- compute proposed location -- + prop_patch[0] = frame_anchor[0] + stride1 * (ws_i - wsOff_h); + prop_patch[1] = frame_anchor[1] + stride1 * (ws_j - wsOff_w); + check_bounds_2d(valid_prop_patch,prop_patch,H,W); + valid = valid_ref_patch && valid_prop_patch; + + // -- init dist -- + dist = 0; + + // -- compute patch difference -- + if (valid){ + + compute_dist_2d(dist, + frame0[ibatch][ihead_fr],frame1[ibatch][ihead_fr], + ref_patch, prop_patch, + ref_pix, prop_pix, valid_ref, valid_prop, + ps,dilation,reflect_bounds, + patch_offset,invalid,C,H,W); + + } + + // -- assignent -- + if (!valid){ dist = invalid; } + dists[ibatch][ihead_sr][qi][si][ws_i][ws_j] = dist; + inds[ibatch][ihead_sr][qi][si][ws_i][ws_j][0] = prop_patch[0]-ref_patch[0]; + inds[ibatch][ihead_sr][qi][si][ws_i][ws_j][1] = prop_patch[1]-ref_patch[1]; + + } + } + } +} + +void paired_refine_int_forward_cuda( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, torch::Tensor dists, torch::Tensor inds, + int ps, int k, int stride0, int stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type){ + + // -- derived quantities -- + int B = frame0.size(0); + int HD_frame = frame0.size(1); + int HD_flow = flow.size(1); + int H = frame0.size(3); + int W = frame0.size(4); + int S = flow.size(2); + // int nH0 = (H-1)/stride0+1; + int HD = max(HD_frame,HD_flow); + + // -- threads -- + int nqueries = dists.size(2); + int ws = dists.size(3); + int ws_threads = std::min(ws,25); + int ws_per_thread = ((ws-1)/ws_threads) + 1; + dim3 nthreads(ws_threads,ws_threads); + + // -- nblocks -- + int q_per_thread 
= 2; + int nquery_blocks = ((nqueries - 1) / q_per_thread) + 1; + dim3 nblocks(nquery_blocks,S,B*HD); + + // -- share -- + // int psHalf = ps/2; + // int adj = use_adj ? psHalf : 0; + // // int patch_offset = adj - psHalf; + // int patch_offset = adj - psHalf; + + // -- viz -- + // fprintf(stdout,"ws_h,ws_w: %d,%d,%d,%d\n",ws_h,ws_w,ws_h_threads,ws_h_per_thread); + // fprintf(stdout,"nquery_blocks,B,HD: %d,%d,%d\n",nquery_blocks,B,HD); + + + // launch kernel + if (dist_type == 0){ + AT_DISPATCH_FLOATING_TYPES(frame0.type(),"paired_refine_int_forward_kernel", ([&] { + paired_refine_int_forward_kernel<<>>( + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + dists.packed_accessor32(), + inds.packed_accessor32(), + ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, + patch_offset, q_per_thread, ws_per_thread); + })); + }else if(dist_type == 1){ + AT_DISPATCH_FLOATING_TYPES(frame0.type(),"paired_refine_int_forward_kernel", ([&] { + paired_refine_int_forward_kernel<<>>( + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + dists.packed_accessor32(), + inds.packed_accessor32(), + ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, + patch_offset, q_per_thread, ws_per_thread); + })); + }else{ + throw std::invalid_argument("Uknown distance type. Must be 0 (product) or 1 (l2)"); + } +} + + +/********************************** + + Forward Pass (Bilin2d) + +**********************************/ + +template +__global__ void paired_refine_bilin2d_forward_kernel( + const torch::PackedTensorAccessor32 frame0, + const torch::PackedTensorAccessor32 frame1, + const torch::PackedTensorAccessor32 flow, + torch::PackedTensorAccessor32 dists, + torch::PackedTensorAccessor32 inds, + torch::PackedTensorAccessor32 kselect, + int ws, int ps, int stride0, float _stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, + int q_per_thread, int ws_per_thread){ + + // -- unpack shape -- + int B = frame0.size(0); + int HD_frame = frame0.size(1); + int HD_flow = flow.size(1); + int HD_search = inds.size(1); + int C = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + int Q = dists.size(2); + int HD = max(HD_frame,HD_flow); + scalar_t stride1 = static_cast(_stride1); + + + // -- invalid constant -- + scalar_t invalid = (scalar_t)__int_as_float(0x7f800000); + if(DIST_TYPE == 0){ // prod + invalid = -invalid; + } + + // -- search region offsets -- + // int psHalf = (ps)/2; + // int wsHalf_h = (ws_h)/2; + // int wsHalf_w = (ws_w)/2; + // int wsMax_h = stride1*(ws_h-1-wsHalf_h); + // int wsMax_w = stride1*(ws_w-1-wsHalf_w); + // int adj = use_adj ? 
psHalf : 0; + + // int wsHalf_h = (ws_h-1)/2; + // int wsHalf_w = (ws_w-1)/2; + // int wsOff_h,wsOff_w; + scalar_t wsHalf = trunc((ws-1)/2); + scalar_t wsOff_h,wsOff_w; + + // -- cuda index -- + int ihead = blockIdx.z/B; + int ibatch = (blockIdx.z-ihead*B) % B; + int si = blockIdx.y; + // int ibatch = blockIdx.y; + // int ihead = blockIdx.z; + int ihead_fr = ihead % HD_frame; + int ihead_fl = ihead % HD_flow; + int ihead_sr = ihead % HD_search; + int q_start = blockIdx.x*q_per_thread; + int qi,ws_i,ws_j; + + // decls + int ref_patch[2]; + scalar_t prop_patch[2]; + scalar_t frame_anchor[2]; + int ref_pix[2]; + scalar_t prop_pix[2]; + // int prop_i[2]; + bool valid; + bool valid_ref_patch,valid_prop_patch; + bool valid_ref[3]; + bool valid_prop[3]; + + // -- indexing -- + scalar_t dist,pix0,pix1; + + for (int q_index = 0; q_index < q_per_thread; q_index++){ + + + //--------------------------- + // Anchor Pixel + //--------------------------- + + // -- block start -- + qi = q_start + q_index; + if (qi >= Q){ continue; } + + // -- pixel location from query index -- + get_pixel_loc_2d(ref_patch,qi,stride0,H,W); + check_bounds_2d(valid_ref_patch,ref_patch,H,W); + int nh = ref_patch[0]/stride0; + int nw = ref_patch[1]/stride0; + + // -- compute frame offsets with flow -- + frame_anchor[0] = ref_patch[0] + flow[ibatch][ihead_fl][si][nh][nw][0]; + frame_anchor[1] = ref_patch[1] + flow[ibatch][ihead_fl][si][nh][nw][1]; + // frame_anchor[0] = ref_patch[0]+flow[ibatch][ihead_fl][1][nh][nw]; + // frame_anchor[1] = ref_patch[1]+flow[ibatch][ihead_fl][0][nh][nw]; + frame_anchor[0] = bounds(frame_anchor[0],H); + frame_anchor[1] = bounds(frame_anchor[1],W); + + // -- search region offsets -- + set_search_offsets(wsOff_h, wsOff_w, + frame_anchor[0], frame_anchor[1], stride1, + wsHalf, ws, H, W, full_ws); + + // --------------------------------------- + // spatial searching + // --------------------------------------- + + // -- search across space -- + for (int _xi = 0; _xi < ws_per_thread; _xi++){ + ws_i = threadIdx.x + blockDim.x*_xi; + if (ws_i >= ws){ continue; } + for (int _yi = 0; _yi < ws_per_thread; _yi++){ + ws_j = threadIdx.y + blockDim.y*_yi; + if (ws_j >= ws){ continue; } + + // -- compute proposed location -- + prop_patch[0] = frame_anchor[0] + stride1 * (ws_i - wsOff_h); + prop_patch[1] = frame_anchor[1] + stride1 * (ws_j - wsOff_w); + check_bounds_2d(valid_prop_patch,prop_patch,H,W); + valid = valid_ref_patch && valid_prop_patch; + + + // -- init dist -- + dist = 0; + // Z = 0; + + // -- compute patch difference -- + if (valid){ + compute_dist_bilin2d_2d(dist, + frame0[ibatch][ihead_fr],frame1[ibatch][ihead_fr], + ref_patch, prop_patch, ref_pix, prop_pix,// prop_i, + valid_ref, valid_prop, ps,dilation,reflect_bounds, + patch_offset,invalid,C,H,W); + // dist /= Z; + } + + + // -- assignent -- + if (!valid){ dist = invalid; } + dists[ibatch][ihead_sr][qi][si][ws_i][ws_j] = dist; + inds[ibatch][ihead_sr][qi][si][ws_i][ws_j][0] = prop_patch[0]-ref_patch[0]; + inds[ibatch][ihead_sr][qi][si][ws_i][ws_j][1] = prop_patch[1]-ref_patch[1]; + kselect[ibatch][ihead][qi][si][ws_i][ws_j] = si; + + } + } + } +} + +void paired_refine_bilin2d_forward_cuda( + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + torch::Tensor dists, torch::Tensor inds, torch::Tensor kselect, + int ps, int k, int stride0, float stride1, int dilation, + bool reflect_bounds, bool full_ws, int patch_offset, int dist_type){ + + // -- derived quantities -- + int B = frame0.size(0); + int HD_frame = 
frame0.size(1); + int HD_flow = flow.size(1); + int H = frame0.size(3); + int W = frame0.size(4); + // int nH0 = (H-1)/stride0+1; + int HD = max(HD_frame,HD_flow); + int S = flow.size(2); + + // -- threads -- + int nqueries = dists.size(2); + int ws = dists.size(4); + int ws_threads = std::min(ws,15); + int ws_per_thread = ((ws-1)/ws_threads) + 1; + dim3 nthreads(ws_threads,ws_threads); + + // -- nblocks -- + int q_per_thread = 2; + int nquery_blocks = ((nqueries - 1) / q_per_thread) + 1; + dim3 nblocks(nquery_blocks,S,B*HD); + + // -- share -- + // int psHalf = ps/2; + // int adj = use_adj ? psHalf : 0; + // // int patch_offset = adj - psHalf; + // int patch_offset = adj - psHalf; + + // -- viz -- + // fprintf(stdout,"ws_h,ws_w: %d,%d\n",ws_h,ws_w); + // fprintf(stdout,"nquery_blocks,B,HD: %d,%d,%d\n",nquery_blocks,B,HD); + + // launch kernel + if (dist_type == 0){ + AT_DISPATCH_FLOATING_TYPES(frame0.type(), + "paired_refine_bilin2d_forward_kernel", ([&] { + paired_refine_bilin2d_forward_kernel<<>>( + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + dists.packed_accessor32(), + inds.packed_accessor32(), + kselect.packed_accessor32(), + ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, + patch_offset, q_per_thread, ws_per_thread); + })); + }else if(dist_type == 1){ + AT_DISPATCH_FLOATING_TYPES(frame0.type(), + "paired_refine_bilin2d_forward_kernel", ([&] { + paired_refine_bilin2d_forward_kernel<<>>( + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + dists.packed_accessor32(), + inds.packed_accessor32(), + kselect.packed_accessor32(), + ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, + patch_offset, q_per_thread, ws_per_thread); + })); + }else{ + throw std::invalid_argument("Uknown distance type. 
Must be 0 (product) or 1 (l2)"); + } +} + + +/**************************** + + Backward Pass + +****************************/ + +template +__global__ void paired_refine_int_backward_kernel( + torch::PackedTensorAccessor32 grad_frame0, + torch::PackedTensorAccessor32 grad_frame1, + const torch::PackedTensorAccessor32 frame0, + const torch::PackedTensorAccessor32 frame1, + const torch::PackedTensorAccessor32 grad_dists, + const torch::PackedTensorAccessor32 inds, + int stride0, int ps, int dilation, int patch_offset, + bool reflect_bounds, int ftrs_per_thread) { + + // -- shape -- + int nbatch = grad_dists.size(0); + int Q = grad_dists.size(2); + int K = grad_dists.size(3); + int HD_frame = frame0.size(1); + int HD_flow = grad_dists.size(1); + int F = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + int HD = max(HD_frame,HD_flow); + + // -- fwd decl registers -- + int ref_patch[2]; + int prop_patch[2]; + int ref[2]; + int prop[2]; + bool valid_ref[3]; + bool valid_prop[3]; + bool valid; + scalar_t weight,pix0,pix1,pix; + // int center_offsets[4] = {off_H0,off_H1,off_W0,off_W1}; + + + // -- location to fill -- + int qi = blockIdx.x*blockDim.x+threadIdx.x; + int ki = blockIdx.y*blockDim.y+threadIdx.y; + int ihead = blockIdx.z/nbatch; + int ihead_fr = ihead % HD_frame; + int ihead_fl = ihead % HD_flow; + int ibatch = (blockIdx.z-ihead*nbatch) % nbatch; + + // -- feature chunk -- + int ftr_start = threadIdx.z * ftrs_per_thread; + int ftr_end = min(F,ftr_start + ftrs_per_thread); + + // -- each region -- + if ((qi < Q) && (ki < K)){ + + // -- pixel location from query index -- + get_pixel_loc_2d(ref_patch,qi,stride0,H,W); + + // -- proposed location -- + prop_patch[0] = ref_patch[0] + inds[ibatch][ihead_fl][qi][ki][0]; + prop_patch[1] = ref_patch[1] + inds[ibatch][ihead_fl][qi][ki][1]; + prop_patch[0] = bounds(prop_patch[0],H); + prop_patch[1] = bounds(prop_patch[1],W); + weight = grad_dists[ibatch][ihead_fl][qi][ki]; + + // -- update patch -- + update_bwd_patch_2d( + grad_frame0[ibatch][ihead_fr], + grad_frame1[ibatch][ihead_fr], + frame0[ibatch][ihead_fr], + frame1[ibatch][ihead_fr], + weight,ref_patch,prop_patch, + ps,dilation,reflect_bounds, + patch_offset,ftr_start,ftr_end, + ref,prop,valid_ref,valid_prop,valid, + H,W,pix0,pix1); + + } +} + +void paired_refine_int_backward_cuda( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor grad_dists, const torch::Tensor inds, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type) { + + + // -- unpack -- + int B = frame0.size(0); + int HD_frame = frame0.size(1); + int HD_flow = grad_dists.size(1); + int F = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + int HD = max(HD_frame,HD_flow); + int nqueries = inds.size(2); + int K = inds.size(3); + int BHD = B*HD; + + // -- launch parameters -- + int nbatch = grad_dists.size(0); + int nq = grad_dists.size(2); + int k = grad_dists.size(3); + int ftr_threads = min(1,F); + dim3 threadsPerBlock(128,4,ftr_threads); + dim3 blocksPerGrid(1, 1, nbatch*HD); + blocksPerGrid.x = ceil(double(nq)/double(threadsPerBlock.x)); + blocksPerGrid.y = ceil(double(k)/double(threadsPerBlock.y)); + int ftrs_per_thread = (F-1)/ftr_threads+1; + + // -- launch kernel -- + if (dist_type == 0){ // prod + AT_DISPATCH_FLOATING_TYPES(frame0.type(),"paired_refine_backward_kernel", ([&] { + paired_refine_int_backward_kernel<<>>( + grad_frame0.packed_accessor32(), + 
grad_frame1.packed_accessor32(), + frame0.packed_accessor32(), + frame1.packed_accessor32(), + grad_dists.packed_accessor32(), + inds.packed_accessor32(), + stride0, ps, dilation, patch_offset, reflect_bounds, + ftrs_per_thread); + })); + }else if (dist_type == 1){ // l2 + AT_DISPATCH_FLOATING_TYPES(frame0.type(),"paired_refine_backward_kernel", ([&] { + paired_refine_int_backward_kernel<<>>( + grad_frame0.packed_accessor32(), + grad_frame1.packed_accessor32(), + frame0.packed_accessor32(), + frame1.packed_accessor32(), + grad_dists.packed_accessor32(), + inds.packed_accessor32(), + stride0, ps, dilation, patch_offset, reflect_bounds, + ftrs_per_thread); + })); + }else{ + throw std::invalid_argument("Uknown distance type. Must be 0 (product) or 1 (l2)"); } + + +} + + + +/**************************** + + Backward Bilinear-2d + +****************************/ + +template +__global__ void paired_refine_bilin2d_backward_kernel( + torch::PackedTensorAccessor32 grad_frame0, + torch::PackedTensorAccessor32 grad_frame1, + torch::PackedTensorAccessor32 grad_flow, + const torch::PackedTensorAccessor32 frame0, + const torch::PackedTensorAccessor32 frame1, + const torch::PackedTensorAccessor32 flow, + const torch::PackedTensorAccessor32 grad_dists, + const torch::PackedTensorAccessor32 grad_inds, + const torch::PackedTensorAccessor32 inds, + const torch::PackedTensorAccessor32 kselect, + int stride0, int ps, int dilation, int patch_offset, bool reflect_bounds) { + + // -- shape -- + int nbatch = grad_dists.size(0); + int Q = grad_dists.size(2); + int K = grad_dists.size(3); + int HD_frame = frame0.size(1); + int HD_flow = grad_flow.size(1); + int HD_search = inds.size(1); + int F = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + int HD = max(HD_frame,HD_flow); + + // -- fwd decl registers -- + int ref_patch[2]; + scalar_t prop_patch[2]; + // int ref[2]; + // scalar_t prop[2]; + // int prop_i[2]; + bool valid_ref[3]; + bool valid_prop[3]; + bool valid; + scalar_t weight; + scalar_t iweight[2]; + // int center_offsets[4] = {off_H0,off_H1,off_W0,off_W1}; + + // -- location to fill -- + int qi = blockIdx.x*blockDim.x+threadIdx.x; + int ki = blockIdx.y*blockDim.y+threadIdx.y; + int ihead = blockIdx.z/nbatch; + int ihead_fr = ihead % HD_frame; + int ihead_fl = ihead % HD_flow; + int ihead_sr = ihead % HD_search; + int ibatch = (blockIdx.z-ihead*nbatch); + + // -- feature chunk -- + // int ftr_start = 0;//threadIdx.z * ftrs_per_thread; + // int ftr_end = F;//min(F,ftr_start + ftrs_per_thread); + + // -- each region -- + if ((qi < Q) && (ki < K)){ + + // -- pixel location from query index -- + get_pixel_loc_2d(ref_patch,qi,stride0,H,W); + int nh = ref_patch[0]/stride0; + int nw = ref_patch[1]/stride0; + + // -- accumulate optical flow update -- + scalar_t acc_dFlows[8]; + #pragma unroll + for (int _idx=0; _idx < 8; _idx++){ + acc_dFlows[_idx] = static_cast(0); + } + + // -- proposed location -- + prop_patch[0] = ref_patch[0] + inds[ibatch][ihead_sr][qi][ki][0]; + prop_patch[1] = ref_patch[1] + inds[ibatch][ihead_sr][qi][ki][1]; + prop_patch[0] = bounds(prop_patch[0],H); + prop_patch[1] = bounds(prop_patch[1],W); + + weight = grad_dists[ibatch][ihead_sr][qi][ki]; + iweight[0] = grad_inds[ibatch][ihead_sr][qi][ki][0]; + iweight[1] = grad_inds[ibatch][ihead_sr][qi][ki][1]; + int kj = kselect[ibatch][ihead][qi][ki]; + + // -- update frames -- + update_bwd_bilin2d_patch_2d( + grad_frame0[ibatch][ihead_fr],grad_frame1[ibatch][ihead_fr], + frame0[ibatch][ihead_fr],frame1[ibatch][ihead_fr], + 
acc_dFlows,weight,ref_patch,prop_patch, + ps,dilation,reflect_bounds,patch_offset, + valid_ref,valid_prop,valid,H,W); + + + // -- update grad_flow from grad_dists,vid0,vid1 -- + scalar_t hi = ref_patch[0] + flow[ibatch][ihead_fl][qi][kj][0]; + scalar_t wi = ref_patch[1] + flow[ibatch][ihead_fl][qi][kj][1]; + int signH = ((hi >= 0) and (hi <= (H-1))) ? 1 : -1; + int signW = ((wi >= 0) and (wi <= (W-1))) ? 1 : -1; + bwd_flow_assign(acc_dFlows,nh,nw,signH,signW,grad_flow[ibatch][ihead_fl]); + + // -- update flows -- + atomicAdd(&(grad_flow[ibatch][ihead_fl][0][nh][nw]),signW*iweight[1]); + atomicAdd(&(grad_flow[ibatch][ihead_fl][1][nh][nw]),signH*iweight[0]); + + } +} + +void paired_refine_bilin2d_backward_cuda( + torch::Tensor grad_frame0, torch::Tensor grad_frame1, + torch::Tensor grad_flow, + const torch::Tensor frame0, const torch::Tensor frame1, + const torch::Tensor flow, + const torch::Tensor grad_dists, const torch::Tensor grad_inds, + const torch::Tensor inds, const torch::Tensor kselect, + int stride0, int ps, int dilation, bool reflect_bounds, + int patch_offset, int dist_type) { + + // -- unpack -- + int HD_frame = frame0.size(1); + int HD_flow = grad_dists.size(1); + int F = frame0.size(2); + int H = frame0.size(3); + int W = frame0.size(4); + // int K = inds.size(3); + // assert(pt == 1); + int HD = max(HD_frame,HD_flow); + + // -- launch parameters -- + int B = grad_dists.size(0); + int Q = grad_dists.size(2); + int K = grad_dists.size(3); + dim3 threadsPerBlock(288,2); + dim3 blocksPerGrid(1, 1, B*HD); + blocksPerGrid.x = ceil(double(Q)/double(threadsPerBlock.x)); + blocksPerGrid.y = ceil(double(K)/double(threadsPerBlock.y)); + + // -- shared -- + // int psHalf = ps/2; + // int adj = use_adj ? psHalf : 0; + // int patch_offset = adj - psHalf; + // int patch_offset = psHalf - adj; + + // -- launch kernel -- + if (dist_type == 0){ // prod + AT_DISPATCH_FLOATING_TYPES(frame0.type(), + "paired_refine_bilin2d_backward_kernel", ([&] { + paired_refine_bilin2d_backward_kernel<<>>( + grad_frame0.packed_accessor32(), + grad_frame1.packed_accessor32(), + grad_flow.packed_accessor32(), + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + grad_dists.packed_accessor32(), + grad_inds.packed_accessor32(), + inds.packed_accessor32(), + kselect.packed_accessor32(), + stride0, ps, dilation, patch_offset, reflect_bounds); + })); + }else if (dist_type == 1){ // l2 + AT_DISPATCH_FLOATING_TYPES(frame0.type(), + "paired_refine_bilin2d_backward_kernel", ([&] { + paired_refine_bilin2d_backward_kernel<<>>( + grad_frame0.packed_accessor32(), + grad_frame1.packed_accessor32(), + grad_flow.packed_accessor32(), + frame0.packed_accessor32(), + frame1.packed_accessor32(), + flow.packed_accessor32(), + grad_dists.packed_accessor32(), + grad_inds.packed_accessor32(), + inds.packed_accessor32(), + kselect.packed_accessor32(), + stride0, ps, dilation, patch_offset, reflect_bounds); + })); + }else{ + throw std::invalid_argument("Uknown distance type. 
Must be 0 (product) or 1 (l2)"); } + + +} + + diff --git a/lib/stnls/search/paired_refine.py b/lib/stnls/search/paired_refine.py new file mode 100644 index 0000000..8ef002e --- /dev/null +++ b/lib/stnls/search/paired_refine.py @@ -0,0 +1,327 @@ + +# -- python -- +import torch as th +import numpy as np +from einops import rearrange + +# -- cpp cuda kernel -- +import stnls_cuda + +# -- package -- +import stnls + +# -- api -- +from stnls.utils import extract_pairs + +# -- local -- +from .utils import shape_frames,allocate_pair_2d,dist_type_select,allocate_vid +from .utils import get_ctx_shell,ensure_flow_shape +from .shared import manage_self,reflect_bounds_warning +from .paired_bwd_impl import paired_refine_backward +from .batching_utils import run_batched,batching_info +from .paired_utils import paired_vids as _paired_vids + + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# Forward Logic +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + +def paired_forward(batchsize,*args): + qshift,nqueries = 0,-1 + return paired_fwd_main(qshift,nqueries,*args) + +def ensure_flow_dim(flow): + if flow.ndim == 4: + flow = flow[:,None] # add nheads + return flow + +def paired_refine_forward(frame0, frame1, flow, + wr, ws, ps, k, dist_type, + stride0, stride1, dilation, pt, + self_action, reflect_bounds, + full_ws, use_adj, topk_mode, itype): + + # -- unpack -- + device = frame0.device + B,HD_fr,C,H,W = frame0.shape + HD_flow = flow.shape[1] + # print(frame0.shape,flow.shape) + assert flow.ndim == 5 + HD = max(HD_flow,HD_fr) + patch_offset = 0 if use_adj else -(ps//2) + + # -- derived shapes -- + nH0 = (H-1)//stride0+1 + nW0 = (W-1)//stride0+1 + Q = nH0*nW0 + + # -- settings from distance type -- + dist_type_i,descending,idist_val = dist_type_select(dist_type) + + # -- allocate results -- + base_shape = (B,HD,Q,wr,wr) + dists,inds = allocate_pair_2d(base_shape,device,frame0.dtype,idist_val,itype) + # print("inds.shape: ",inds.shape) + + # -- forward -- + if itype == "int": + flow = flow.round() + inds = inds.int() + stride1 = max(1,int(stride1)) + fwd_fxn = stnls_cuda.paired_refine_int_forward + else: + fwd_fxn = stnls_cuda.paired_refine_bilin2d_forward + stride1 = float(stride1) + # print(frame0.shape,flow.shape,dists.shape,inds.shape) + fwd_fxn(frame0, frame1, flow, dists, inds, + ws, ps, k, stride0, stride1, dilation, + reflect_bounds, full_ws, patch_offset, dist_type_i) + + # print(frame0.shape,frame1.shape,flow.shape,inds.shape) + # -- compress search region -- + dists=dists.view(B,HD,Q,-1) + inds=inds.view(B,HD,Q,-1,2) + # th.cuda.synchronize() + + # -- topk -- + assert self_action in [None,"anchor","anchor_each"] + anchor_self = False if self_action is None else "anchor" in self_action + if topk_mode == "all": + dim = 3 + dists=dists.view(B,HD,Q,Ks*wr*wr) + inds=inds.view(B,HD,Q,Ks*wr*wr,3) + dists,inds,order = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, + descending=descending,unique=False, + return_order=True) + if kselect.ndim > 1: + # print("kselect.shape: ",kselect.shape,order.shape) + kselect = kselect.view(B,HD,Q,Ks*wr*wr) if not(kselect is None) else kselect + # print("kselect.shape: ",kselect.shape,order.shape) + kselect = stnls.nn.topk_f.apply_topk(kselect,order,dim) + elif topk_mode == "each": + # print(dists.shape,kselect.shape) + dists = rearrange(dists,'... wh ww -> ... (wh ww)') + inds = rearrange(inds,'... wh ww d2or3 -> ... (wh ww) d2or3') + dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self) + if kselect.ndim > 1: + kselect = rearrange(kselect,'... wh ww -> ... 
(wh ww)') + kselect = kselect[...,:k] # all same across dim + else: + raise ValueError(f"Unknown topk_mode [{topk_mode}]") + + # -- topk -- + if k > 0: + dim = 3 + dists=dists.view(B,HD,Q,ws*ws) + inds=inds.view(B,HD,Q,ws*ws,2) + dists,inds = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, + descending=descending) + + # -- reshape -- + dists=dists.reshape(B,HD,1,nH0,nW0,-1) + inds=inds.reshape(B,HD,1,nH0,nW0,-1,2) + + return dists,inds + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# Pytorch Function +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + +class PairedRefineFunction(th.autograd.Function): + + + @staticmethod + def forward(ctx, frame0, frame1, flow, + wr, ws, ps, k, nheads=1, + dist_type="prod", stride0=4, stride1=1, + dilation=1, pt=1, reflect_bounds=True, + full_ws=True, self_action=None, use_adj=False, + normalize_bwd=False, k_agg=-1, topk_mode="each", itype="float"): + + """ + Run the non-local search + + frame0 = [B,T,C,H,W] or [B,HD,T,C,H,W] + ws = search Window Spatial (ws) + """ + + # -- reshape with heads -- + dtype = frame0.dtype + device = frame0.device + ctx.in_ndim = frame0.ndim + frame0,frame1 = shape_frames(nheads,[frame0,frame1]) + # print("frame0.shape: ",frame0.shape) + flow = ensure_flow_dim(flow) + # flow = ensure_flow_shape(flow) + B,HD,F,H,W = frame0.shape + flow = flow.contiguous() + reflect_bounds_warning(reflect_bounds) + + # -- run [optionally batched] forward function -- + dists,inds = paired_refine_forward(frame0, frame1, flow, + ws, ps, k, dist_type, + stride0, stride1, dilation, pt, + self_action, reflect_bounds, full_ws, + use_adj, topk_mode, itype) + + # -- setup ctx -- + dist_type_i = dist_type_select(dist_type)[0] + flow = get_ctx_shell(flow,itype=="int") + ctx.save_for_backward(inds,frame0,frame1,flow) + if itype == "int": + ctx.mark_non_differentiable(inds) + ctx.vid_shape = frame0.shape + ctx_vars = {"stride0":stride0,"stride1":stride1, + "wr":wr,"ps":ps,"pt":pt,"ws":ws,"dil":dilation, + "reflect_bounds":reflect_bounds, + "normalize_bwd":normalize_bwd, + "k_agg":k_agg,"use_adj":use_adj, + "dist_type_i":dist_type_i,"itype":itype} + for name,val in ctx_vars.items(): + setattr(ctx,name,val) + + # -- return -- + return dists,inds + + @staticmethod + def backward(ctx, grad_dists, grad_inds): + grad0,grad1,gflow = paired_refine_backward(ctx, grad_dists, grad_inds) + return grad0,grad1,gflow,None,None,None,None,None,None,None,None,\ + None,None,None,None,None,None,None,None,None,None,None,None,None,\ + None,None,None,None,None,None,None,None,None,None,None,None,None,None + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# Pytorch Module +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + + +class PairedRefine(th.nn.Module): + + def __init__(self, wr, ws, ps, k, nheads=1, + dist_type="l2", stride0=4, stride1=1, + dilation=1, pt=1, reflect_bounds=True, + full_ws=True, self_action=None, use_adj=False, + normalize_bwd=False, k_agg=-1, topk_mode="each", itype="float"): + super().__init__() + + # -- core search params -- + self.wr = wr + self.ws = ws + self.ps = ps + self.k = k + self.nheads = nheads + self.dist_type = dist_type + self.stride0 = stride0 + self.stride1 = stride1 + self.dilation = dilation + self.pt = pt + self.itype = itype + + # -- manage patch and search boundaries -- + self.reflect_bounds = reflect_bounds + self.full_ws = full_ws + self.use_adj = use_adj + self.topk_mode = topk_mode + + # -- special mods to "self" search -- + self.self_action = self_action + + # -- backprop params -- + self.normalize_bwd = normalize_bwd + self.k_agg = k_agg + + + def 
paired_vids(self, vid0, vid1, flows, wt, skip_self=False):
+        return _paired_vids(self.forward, vid0, vid1, flows, wt, skip_self)
+
+    def forward(self, frame0, frame1, flow):
+        assert self.ws > 0,"Must have nonzero spatial search window"
+        return PairedRefineFunction.apply(frame0,frame1,flow,
+                                          self.wr,self.ws,self.ps,self.k,
+                                          self.nheads,self.dist_type,self.stride0,
+                                          self.stride1,self.dilation,self.pt,
+                                          self.reflect_bounds,self.full_ws,
+                                          self.self_action,self.use_adj,
+                                          self.normalize_bwd,self.k_agg,
+                                          self.topk_mode,self.itype)
+
+    def flops(self,T,F,H,W):
+        return 0
+
+        # -- unpack --
+        ps,pt = self.ps,self.pt
+
+        # -- compute search --
+        nrefs_hw = ((H-1)//self.stride0+1) * ((W-1)//self.stride0+1)
+        nrefs = T * HD * nrefs_hw
+        nsearch = ws_h * ws_w
+        flops_per_search = 2 * F * ps * ps * pt
+        search_flops = nrefs * nsearch * flops_per_search
+        flops = search_flops
+
+        # -- compute top-k --
+        if self.k > 0:
+            sort_flops = nrefs * (nsearch * np.log(nsearch))
+            flops += sort_flops
+
+        return flops
+
+    def radius(self,H,W):
+        return self.ws
+
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+#
+# [Functional API] stnls.search.paired(...)
+#
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+
+def _apply(frame0, frame1, flow,
+           wr, ws, ps, k, nheads=1, batchsize=-1,
+           dist_type="l2", stride0=4, stride1=1,
+           dilation=1, pt=1, reflect_bounds=True,
+           full_ws=True,self_action=None,
+           use_adj=False, normalize_bwd=False, k_agg=-1,
+           topk_mode="each",itype="float"):
+    # wrap "new (2018) apply function
+    # https://discuss.pytorch.org #13845/17
+    # cfg = extract_config(kwargs)
+    fxn = PairedRefineFunction.apply
+    return fxn(frame0,frame1,flow,wr,ws,ps,k,
+               nheads,batchsize,dist_type,
+               stride0,stride1,dilation,pt,reflect_bounds,
+               full_ws,self_action,use_adj,normalize_bwd,k_agg,
+               topk_mode,itype)
+
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+#
+# [Python Dict API] stnls.search.init(pydict)
+#
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+
+def extract_config(cfg,restrict=True):
+    pairs = {"wr":1,"ws":-1,"ps":3,"k":10,
+             "nheads":1,"dist_type":"l2",
+             "stride0":1, "stride1":1, "dilation":1, "pt":1,
+             "reflect_bounds":True, "full_ws":True,
+             "self_action":None,"use_adj":False,
+             "normalize_bwd": False, "k_agg":-1,
+             "topk_mode":"each","itype":"float",}
+    return extract_pairs(cfg,pairs,restrict=restrict)
+
+def init(cfg):
+    cfg = extract_config(cfg,False)
+    search = PairedRefine(cfg.wr, cfg.ws, cfg.ps, cfg.k, nheads=cfg.nheads,
+                          dist_type=cfg.dist_type, stride0=cfg.stride0,
+                          stride1=cfg.stride1, dilation=cfg.dilation, pt=cfg.pt,
+                          reflect_bounds=cfg.reflect_bounds,
+                          full_ws=cfg.full_ws, self_action=cfg.self_action,
+                          use_adj=cfg.use_adj,normalize_bwd=cfg.normalize_bwd,
+                          k_agg=cfg.k_agg,topk_mode=cfg.topk_mode,itype=cfg.itype)
+    return search
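
A minimal usage sketch of the module added by this patch. It assumes the new file is importable as stnls.search.paired_refine (mirroring the existing search modules in lib/stnls/search), and the tensor shapes, in particular the flow layout, are assumptions inferred from the kernel indexing and the stride0 query grid; since this commit is only a skeleton, the Python-side plumbing may still change.

    import torch as th
    from stnls.search.paired_refine import PairedRefine

    # Build the refine search; arguments mirror PairedRefine.__init__ above.
    search = PairedRefine(wr=1, ws=9, ps=7, k=10, nheads=1,
                          dist_type="l2", stride0=4, stride1=1, itype="float")

    # One pair of frames on the GPU (the kernels are CUDA-only).
    B, C, H, W = 1, 3, 64, 64
    frame0 = th.randn(B, C, H, W, device="cuda")
    frame1 = th.randn(B, C, H, W, device="cuda")

    # Candidate flow offsets: one 2d offset per query location on the stride0 grid.
    # The heads/candidate axes are assumed; ensure_flow_dim inserts a heads axis
    # when it is missing from a 4-dim flow.
    nH, nW = (H - 1) // 4 + 1, (W - 1) // 4 + 1
    flow = th.zeros(B, 1, nH, nW, 2, device="cuda")

    # Per-query distances and the refined 2d offsets.
    dists, inds = search(frame0, frame1, flow)

The dict-style entry point at the bottom of the file builds the same module from a config, e.g. init({"ws": 9, "ps": 7, "k": 10}), assuming extract_pairs fills in the remaining defaults and returns an attribute-style dict.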