From 0dc1a46c35fee92b5faa6b457dbc801b8047b21f Mon Sep 17 00:00:00 2001 From: gauenk Date: Sun, 3 Dec 2023 14:21:33 -0500 Subject: [PATCH] updated slic; fixed modded agg; added slic in dev --- lib/csrc/agg/pool.cpp | 158 ++++++++ lib/csrc/agg/pool_int_kernel.cu | 358 +++++++++++++++++++ lib/csrc/agg/wpsum_int_kernel.cu | 9 +- lib/csrc/pybind.cpp | 2 + lib/csrc/search/refinement_bilin2d_kernel.cu | 6 + lib/csrc/search/refinement_int_kernel.cu | 4 + lib/stnls/agg/__init__.py | 3 + lib/stnls/agg/pool.py | 273 ++++++++++++++ lib/stnls/agg/scatter_labels.py | 10 +- lib/stnls/agg/scatter_tensor.py | 4 +- lib/stnls/dev/__init__.py | 1 + lib/stnls/dev/slic/__init__.py | 262 ++++++++++++++ scripts/slic.py | 104 ++++-- setup.py | 3 + 14 files changed, 1150 insertions(+), 47 deletions(-) create mode 100644 lib/csrc/agg/pool.cpp create mode 100644 lib/csrc/agg/pool_int_kernel.cu create mode 100644 lib/stnls/agg/pool.py create mode 100644 lib/stnls/dev/slic/__init__.py diff --git a/lib/csrc/agg/pool.cpp b/lib/csrc/agg/pool.cpp new file mode 100644 index 0000000..f74114a --- /dev/null +++ b/lib/csrc/agg/pool.cpp @@ -0,0 +1,158 @@ +#include + +#include + +// CUDA forward declarations + +/************************************* + + Int Forward + + *************************************/ + +void pool_int_forward_cuda( + torch::Tensor out_vid, torch::Tensor counts, + const torch::Tensor in_vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, + bool reflect_bounds, int patch_offset); + +void pool_int_backward_cuda( + torch::Tensor in_vid_grad, + torch::Tensor dists_grad, + const torch::Tensor out_vid_grad, const torch::Tensor vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset); + +/************************************* + + Bilin2d Forward + + *************************************/ + +// void pool_bilin2d_forward_cuda( +// torch::Tensor out_vid, torch::Tensor counts, +// const torch::Tensor in_vid, +// const torch::Tensor dists, const torch::Tensor inds, +// int ps, int stride0, int pt, int dilation, +// bool reflect_bounds, int patch_offset); + +// void pool_bilin2d_backward_cuda( +// torch::Tensor in_vid_grad, +// torch::Tensor dists_grad, +// torch::Tensor inds_grad, +// const torch::Tensor out_vid_grad, const torch::Tensor vid, +// const torch::Tensor dists, const torch::Tensor inds, +// int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset); + + +// C++ interface + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +/*********************** + + + Int Indexing + + +***********************/ + + +void pool_int_forward( + torch::Tensor out_vid, torch::Tensor counts, + const torch::Tensor in_vid, + const torch::Tensor dists, + const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, + bool reflect_bounds, int patch_offset){ + CHECK_INPUT(out_vid); + CHECK_INPUT(counts); + CHECK_INPUT(in_vid); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + pool_int_forward_cuda(out_vid,counts,in_vid,dists,inds, + ps,stride0,pt,dilation,reflect_bounds,patch_offset); +} + +void pool_int_backward( // "in" and "out" w.r.t. 
forward pass + torch::Tensor in_vid_grad, torch::Tensor dists_grad, + const torch::Tensor out_vid_grad, const torch::Tensor vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset){ + CHECK_INPUT(in_vid_grad); + CHECK_INPUT(dists_grad); + CHECK_INPUT(out_vid_grad); + CHECK_INPUT(vid); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + pool_int_backward_cuda(in_vid_grad,dists_grad, + out_vid_grad,vid,dists,inds, + ps,stride0,pt,dilation,reflect_bounds,patch_offset); +} + +/*********************** + + + Bilinear2d + + +***********************/ + +void pool_bilin2d_forward( + torch::Tensor out_vid, torch::Tensor counts, + const torch::Tensor in_vid, + const torch::Tensor dists, + const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, + bool reflect_bounds, int patch_offset){ + CHECK_INPUT(out_vid); + CHECK_INPUT(counts); + CHECK_INPUT(in_vid); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + // pool_bilin2d_forward_cuda(out_vid,counts,in_vid,dists,inds, + // ps,stride0,pt,dilation,reflect_bounds,patch_offset); +} + +void pool_bilin2d_backward( // "in" and "out" w.r.t. forward pass + torch::Tensor in_vid_grad, + torch::Tensor dists_grad, torch::Tensor inds_grad, + const torch::Tensor out_vid_grad, const torch::Tensor vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset){ + CHECK_INPUT(in_vid_grad); + CHECK_INPUT(dists_grad); + CHECK_INPUT(inds_grad); + CHECK_INPUT(out_vid_grad); + CHECK_INPUT(vid); + CHECK_INPUT(dists); + CHECK_INPUT(inds); + // pool_bilin2d_backward_cuda(in_vid_grad,dists_grad, + // inds_grad, + // out_vid_grad,vid,dists,inds, + // ps,stride0,pt,dilation,reflect_bounds,patch_offset); +} + +/*********************** + + + Python Bindings + + +***********************/ + +void init_pool(py::module &m){ + m.def("pool_int_forward", &pool_int_forward, + "WeightedPatchSum Forward (CUDA)"); + m.def("pool_int_backward", &pool_int_backward, + "WeightedPatchSum Backward (CUDA)"); + // m.def("pool_bilin2d_forward", &pool_bilin2d_forward, + // "WeightedPatchSum Forward (CUDA)"); + // m.def("pool_bilin2d_backward", &pool_bilin2d_backward, + // "WeightedPatchSum Backward (CUDA)"); + +} + diff --git a/lib/csrc/agg/pool_int_kernel.cu b/lib/csrc/agg/pool_int_kernel.cu new file mode 100644 index 0000000..2945919 --- /dev/null +++ b/lib/csrc/agg/pool_int_kernel.cu @@ -0,0 +1,358 @@ + +// #include +#include +#include +#include +#include +#include "../shared_kernel.cu" + +/**************************** + + Forward Pass + +****************************/ + +template +__global__ void pool_int_forward_kernel( + torch::PackedTensorAccessor32 out_vid, + torch::PackedTensorAccessor32 counts, + const torch::PackedTensorAccessor32 in_vid, + const torch::PackedTensorAccessor32 dists, + const torch::PackedTensorAccessor32 inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, + int patch_offset, int w_nW, int w_nHW, int q_per_thread){ + + // -- shapes -- + int B = in_vid.size(0); + int HD = in_vid.size(1); + int T = in_vid.size(2); + int F = in_vid.size(3); + int inH = in_vid.size(4); + int inW = in_vid.size(5); + int outH = out_vid.size(4); + int outW = out_vid.size(5); + int Q = inds.size(2); + int K = inds.size(3); + + // -- batching -- + int query_start = (threadIdx.x + blockDim.x*blockIdx.x)*q_per_thread; + // int query_start = blockIdx.x*blockDim.x+threadIdx.x; + int ki = blockIdx.y*blockDim.y+threadIdx.y; + 
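+  // note (added): the third grid dimension packs (head, batch) pairs as
+  // blockIdx.z = ihead*B + ibatch; the two lines below unpack both indices.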
int ihead = blockIdx.z/B; + int ibatch = (blockIdx.z-ihead*B) % B; + if (ki >= K){ return; } + + // -- pixel locations -- + int qi; + bool valid; + scalar_t pix,weight; + int ref_ti,nl_ti; + int ref[3],nl[3],nl_p[3]; + int wref[3],wref_p[3]; + int nW = (inW-1)/stride0+1; + int nHW = nW*((inH-1)/stride0+1); + int psHalf = (ps-1)/2; + + // -- across queries -- + for(int _qi = 0; _qi < q_per_thread; _qi++){ + + // -- query index -- + qi = query_start + _qi; + if (qi >= Q){ return; } + + // -- non-local weight -- + weight = dists[ibatch][ihead][qi][ki]; + if (weight < 1e-8){ continue; } + // if (ki == 0){ continue; } + + // -- write location -- + get_pixel_loc(wref,qi,ps,nW,nHW,outH,outW); + + // -- non-local index -- + get_pixel_loc(ref,qi,stride0,nW,nHW,inH,inW); +#pragma unroll + for (int _idx=0; _idx < 3; _idx++){ + nl[_idx] = ref[_idx] + inds[ibatch][ihead][qi][ki][_idx]; + } + + // -- always reflect anchor point -- + nl[0] = bounds(nl[0],T); + nl[1] = bounds(nl[1],inH); + nl[2] = bounds(nl[2],inW); + + // -- iterate over patches -- + for(int pi=0; pi < ps; pi++){ + for(int pj=0; pj < ps; pj++){ + + // -- write pixel index -- + wref_p[0] = wref[0]; + wref_p[1] = wref[1]+psHalf+(pi + patch_offset); + wref_p[2] = wref[2]+psHalf+(pj + patch_offset); + check_bounds(valid, wref_p, T, outH, outW); + if (not valid){ continue; } + + // -- normalize -- + if ((wref[0]==0) and (ibatch==0) and (ihead==0) and (ki==0)){ + atomicAdd(&counts[wref_p[1]][wref_p[2]],1); + } + + // -- non-local pixel index -- + nl_p[0] = nl[0]; + nl_p[1] = nl[1]+dilation*(pi + patch_offset); + nl_p[1] = reflect_bounds ? bounds(nl_p[1],inH) : nl_p[1]; + nl_p[2] = nl[2]+dilation*(pj + patch_offset); + nl_p[2] = reflect_bounds ? bounds(nl_p[2],inW) : nl_p[2]; + check_bounds(valid, nl_p, T, inH, inW); + if (not valid){ continue; } + + // -- iterate over loop -- + for(int pk = 0; pk < pt; pk++){ + + // -- time is always valid -- + ref_ti = wref_p[0] + pk; + nl_ti = reflect_bounds ? 
bounds(nl_p[0]+pk,T) : (nl_p[0]+pk); + valid = (nl_ti >= 0) && (nl_ti < T) and (ref_ti < T); + if (not valid){ continue; } + + // -- channels -- + for(int iftr = 0; iftr < F; iftr++){ + + // -- fill -- + pix = weight*in_vid[ibatch][ihead][nl_ti][iftr][nl_p[1]][nl_p[2]]; + atomicAdd(&out_vid[ibatch][ihead][ref_ti][iftr][wref_p[1]][wref_p[2]],pix); + + } // nfeatures-loop + } // pt-loop + }} // pi,pj + } // query-loop +} + +void pool_int_forward_cuda( + torch::Tensor out_vid, torch::Tensor counts, + const torch::Tensor in_vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, + bool reflect_bounds, int patch_offset){ + + // -- unpack -- + int B = inds.size(0); + int HD = inds.size(1); + int Q = inds.size(2); + int K = inds.size(3); + int q_per_thread = 2; + + // -- output dimensions -- + // int psHalf = (ps-1)/2; + int inH = in_vid.size(4); + int inW = in_vid.size(5); + int outH = out_vid.size(4); + int outW = out_vid.size(5); + int w_nW = (outW-1)/ps+1; + int w_nH = (outH-1)/ps+1; + int w_nHW = w_nW*w_nH; + int nW = (inW-1)/stride0+1; + int nH = (inH-1)/stride0+1; + // fprintf(stdout,"w_nH,w_nW: %d,%d\n",w_nH,w_nW); + // fprintf(stdout,"nH,nW: %d,%d\n",nH,nW); + + // -- kernel threads -- + int MAX_THREADS = 512;//1024 + int k_threads = 8; + int q_threads = MAX_THREADS/(k_threads); // num of queries threads per block + q_threads = min(Q,q_threads); + int q_blocks = (Q-1)/(q_per_thread*q_threads)+1; + int k_blocks = (K-1)/(k_threads)+1; + dim3 nthreads(q_threads,k_threads); + // fprintf(stdout, + // "ps,pt,stride0,reflect_bounds,dilation,patch_offset: %d,%d,%d,%d,%d,%d\n", + // ps,pt,stride0,reflect_bounds,dilation,patch_offset); + // -- kernel blocks -- + dim3 nblocks(q_blocks,k_blocks,B*HD); + + + // // -- kernel threads -- + // int MAX_THREADS = 1024; + // int q_threads = MAX_THREADS/(ps*ps); // num of queries threads per block + // q_threads = min(Q,q_threads); + // int q_blocks = (Q-1)/(q_per_thread*q_threads)+1; + // dim3 nthreads(q_threads,ps,ps); + // // fprintf(stdout,"ps,reflect_bounds,patch_offset: %d,%d,%d\n",ps,reflect_bounds,patch_offset); + + // // -- kernel blocks -- + // dim3 nblocks(q_blocks,B,HD); + + // -- launch kernel -- + AT_DISPATCH_FLOATING_TYPES(in_vid.type(), "pool_int_forward_kernel", ([&] { + pool_int_forward_kernel<<>>( + out_vid.packed_accessor32(), + counts.packed_accessor32(), + in_vid.packed_accessor32(), + dists.packed_accessor32(), + inds.packed_accessor32(), + ps, stride0, pt, dilation, reflect_bounds, patch_offset, + w_nW, w_nHW, q_per_thread); + })); +} + + + +/************************************ + + Backward Pass (for Vid & Dists) + +*************************************/ + +template +__global__ void pool_int_backward_kernel( + torch::PackedTensorAccessor32 in_vid_grad, + torch::PackedTensorAccessor32 dists_grad, + const torch::PackedTensorAccessor32 out_vid_grad, + const torch::PackedTensorAccessor32 vid, + const torch::PackedTensorAccessor32 dists, + const torch::PackedTensorAccessor32 inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset, + int q_per_thread, int k_per_thread){ + + // -- shape -- + int B = dists.size(0); + int HD = dists.size(1); + int Q = dists.size(2); + int K = dists.size(3); + int T = out_vid_grad.size(2); + int F = out_vid_grad.size(3); + int H = out_vid_grad.size(4); + int W = out_vid_grad.size(5); + + // -- pixel indexing -- + int qi,ki; + int ref[3],ref_p[3],nl[3]; + int ref_ti,nl_ti; + bool valid; + float weight,pix_n,pix_m; + + // -- 
location to fill -- + int q_start = q_per_thread*(blockIdx.x*blockDim.x+threadIdx.x); + int k_start = 0; + int ihead = blockIdx.y/B; + int ibatch = (blockIdx.y-ihead*B); + int nW = (W-1)/stride0+1; + int nHW = nW*((H-1)/stride0+1); + + // -- cuda threads -- + int pi = threadIdx.y; + int pj = threadIdx.z; + + // -- across queries -- + for(int _qi = 0; _qi < q_per_thread; _qi++){ + + // -- query index -- + qi = q_start + _qi; + if (qi >= Q){ continue; } + get_pixel_loc(ref,qi,stride0,nW,nHW,H,W); + + // -- reference pixel index -- + ref_p[0] = ref[0]; + ref_p[1] = ref[1]+dilation*(pi + patch_offset); + ref_p[2] = ref[2]+dilation*(pj + patch_offset); + + // -- valid ref pixel only -- + check_bounds(valid, ref_p, T, H, W); + if (not valid){ continue; } + + for(int _ki = 0; _ki < k_per_thread; _ki++){ + + // -- non-local index -- + ki = k_start + _ki; + if (ki >= K){ continue; } + #pragma unroll + for (int _idx=0; _idx < 3; _idx++){ + nl[_idx] = ref[_idx] + inds[ibatch][ihead][qi][ki][_idx]; + } + + // -- reflect -- + nl[0] = bounds(nl[0],T); + nl[1] = bounds(nl[1],H); + nl[2] = bounds(nl[2],W); + + // -- non-local pixel index -- + nl[1] = nl[1]+dilation*(pi + patch_offset); + nl[1] = reflect_bounds ? bounds(nl[1],H) : nl[1]; + nl[2] = nl[2]+dilation*(pj + patch_offset); + nl[2] = reflect_bounds ? bounds(nl[2],W) : nl[2]; + + // -- valid non-local patches only -- + check_bounds(valid, nl, T, H, W); + if (not valid){ continue; } + + // -- non-local weight -- + weight = dists[ibatch][ihead][qi][ki]; + scalar_t acc_dists_grad = 0; + + for (int pk = 0; pk < pt; pk++){ + + // -- time is always valid -- + ref_ti = ref_p[0] + pk; + nl_ti = reflect_bounds ? bounds(nl[0]+pk,T) : (nl[0]+pk); + valid = (nl_ti >= 0) && (nl_ti < T) and (ref_ti < T); + if (not valid){ continue; } + + // -- num features -- + for (int iftr = 0; iftr < F; iftr++){ + pix_n = out_vid_grad[ibatch][ihead][ref_ti][iftr][ref_p[1]][ref_p[2]]; + pix_m = vid[ibatch][ihead][nl_ti][iftr][nl[1]][nl[2]]; + atomicAdd(&in_vid_grad[ibatch][ihead][nl_ti][iftr][nl[1]][nl[2]],weight*pix_n); + acc_dists_grad += pix_n*pix_m; + } + + } // pt + + // -- write dist grad -- + atomicAdd(&dists_grad[ibatch][ihead][qi][ki],acc_dists_grad); + + } // ki + } // qi +} + +void pool_int_backward_cuda( + torch::Tensor in_vid_grad, torch::Tensor dists_grad, + const torch::Tensor out_vid_grad, const torch::Tensor vid, + const torch::Tensor dists, const torch::Tensor inds, + int ps, int stride0, int pt, int dilation, bool reflect_bounds, int patch_offset){ + + // -- launch parameters -- + int B = dists.size(0); + int HD = dists.size(1); + int Q = dists.size(2); + int K = dists.size(3); + int q_per_thread = 1; + int k_per_thread = K; + // fprintf(stdout, + // "ps,stride0,pt,dilation,reflect_bounds,patch_offset: %d,%d,%d,%d,%d,%d\n", + // ps,stride0,pt,dilation,reflect_bounds,patch_offset); + + // -- kernel threads -- + int MAX_THREADS = 768; + int q_threads = MAX_THREADS/(ps*ps); // num of queries threads per block + q_threads = min(Q,q_threads); + int q_blocks = (Q-1)/(q_per_thread*q_threads)+1; + int k_blocks = (K-1)/k_per_thread+1; + dim3 nthreads(q_threads,ps,ps); + dim3 nblocks(q_blocks, HD*B); + + // fprintf(stdout,"q_threads: %d\n",q_threads); + + // launch kernel + AT_DISPATCH_FLOATING_TYPES(in_vid_grad.type(), "pool_int_backward_vid_kernel", ([&] { + pool_int_backward_kernel<<>>( + in_vid_grad.packed_accessor32(), + dists_grad.packed_accessor32(), + out_vid_grad.packed_accessor32(), + vid.packed_accessor32(), + dists.packed_accessor32(), + 
inds.packed_accessor32(), + ps, stride0, pt, dilation, reflect_bounds, patch_offset, + q_per_thread, k_per_thread); + })); + +} + diff --git a/lib/csrc/agg/wpsum_int_kernel.cu b/lib/csrc/agg/wpsum_int_kernel.cu index af3745d..6a7d237 100644 --- a/lib/csrc/agg/wpsum_int_kernel.cu +++ b/lib/csrc/agg/wpsum_int_kernel.cu @@ -33,13 +33,14 @@ __global__ void wpsum_int_forward_kernel( int K = inds.size(3); // -- batching -- - // int query_start = (threadIdx.x + blockDim.x*blockIdx.x)*q_per_thread; - int query_start = blockIdx.x*blockDim.x+threadIdx.x; + int query_start = (threadIdx.x + blockDim.x*blockIdx.x)*q_per_thread; + // int query_start = blockIdx.x*blockDim.x+threadIdx.x; int ki = blockIdx.y*blockDim.y+threadIdx.y; int ihead = blockIdx.z/B; int ibatch = (blockIdx.z-ihead*B) % B; // int ibatch = blockIdx.y; // int ihead = blockIdx.z; + if (ki >= K){ return; } // // -- cuda threads -- // int pi = threadIdx.y; @@ -88,6 +89,10 @@ __global__ void wpsum_int_forward_kernel( nl[_idx] = ref[_idx] + inds[ibatch][ihead][qi][ki][_idx]; } + // -- check "inf" (but it won't be inf sometimes) -- + valid = (abs(nl[1]) < 1e7) and (abs(nl[2]) < 1e7); + if (not(valid)){ continue; } + // -- always reflect anchor point -- nl[0] = bounds(nl[0],T); nl[1] = bounds(nl[1],H); diff --git a/lib/csrc/pybind.cpp b/lib/csrc/pybind.cpp index 68b066c..206b44a 100644 --- a/lib/csrc/pybind.cpp +++ b/lib/csrc/pybind.cpp @@ -16,6 +16,7 @@ void init_non_local_inds(py::module &); // -- agg -- void init_wpsum(py::module &); +void init_pool(py::module &); void init_gather(py::module &); void init_scatter(py::module &); @@ -36,6 +37,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // -- agg -- init_wpsum(m); + init_pool(m); init_gather(m); init_scatter(m); diff --git a/lib/csrc/search/refinement_bilin2d_kernel.cu b/lib/csrc/search/refinement_bilin2d_kernel.cu index a9ff05d..407bc2f 100644 --- a/lib/csrc/search/refinement_bilin2d_kernel.cu +++ b/lib/csrc/search/refinement_bilin2d_kernel.cu @@ -107,6 +107,12 @@ __global__ void refinement_bilin2d_forward_kernel( prop_center[1] = ref_patch[2] + flows[ibatch][ihead_f][ti][nh][nw][ki][2]; prop_patch[0] = bounds(prop_patch[0],T); + // -- possibly illegal flows -- + valid = abs(flows[ibatch][ihead_f][ti][nh][nw][ki][1]) < 1e8; + valid = valid and abs(flows[ibatch][ihead_f][ti][nh][nw][ki][2]) < 1e8; + if (not(valid)){ continue; } + + // -- bounding -- reflect[ibatch][ihead_f][ti][nh][nw][ki][0] = not check_bound(prop_center[0],H); reflect[ibatch][ihead_f][ti][nh][nw][ki][1] = not check_bound(prop_center[1],W); prop_center[0] = bounds(prop_center[0],H); diff --git a/lib/csrc/search/refinement_int_kernel.cu b/lib/csrc/search/refinement_int_kernel.cu index a2cded1..5a70dc3 100644 --- a/lib/csrc/search/refinement_int_kernel.cu +++ b/lib/csrc/search/refinement_int_kernel.cu @@ -102,6 +102,10 @@ __global__ void refinement_forward_kernel( prop_center[0] = bounds(prop_center[0],H); prop_center[1] = bounds(prop_center[1],W); + // -- possibly illegal flows -- + valid = abs(flows[ibatch][ihead_f][ti][nh][nw][ki][1]) < 1e8; + valid = valid and abs(flows[ibatch][ihead_f][ti][nh][nw][ki][2]) < 1e8; + if (not(valid)){ continue; } // -- search region offsets -- set_search_offsets(wrOff_h, wrOff_w, diff --git a/lib/stnls/agg/__init__.py b/lib/stnls/agg/__init__.py index 2b3d5f2..2dc372b 100644 --- a/lib/stnls/agg/__init__.py +++ b/lib/stnls/agg/__init__.py @@ -1,6 +1,7 @@ # -- modules -- from . import wpsum as wpsum_f +from . import pool as pool_f from . import gather as gather_f from . 
import scatter as scatter_f from . import scatter_labels as scatter_labels_f @@ -12,6 +13,7 @@ # -- functional api -- wpsum = wpsum_f._apply +pool = pool_f._apply gather = gather_f._apply scatter = scatter_f._apply scatter_labels = scatter_labels_f.run @@ -23,3 +25,4 @@ NonLocalGather = gather_f.NonLocalGather NonLocalScatter = scatter_f.NonLocalScatter WeightedPatchSum = wpsum_f.WeightedPatchSum +PooledPatchSum = pool_f.PooledPatchSum diff --git a/lib/stnls/agg/pool.py b/lib/stnls/agg/pool.py new file mode 100644 index 0000000..257e85b --- /dev/null +++ b/lib/stnls/agg/pool.py @@ -0,0 +1,273 @@ +""" + +Usage: see scripts/example_attn.py + +""" + +# -- python -- +import torch as th +from einops import rearrange + +# -- cpp cuda kernel -- +import stnls_cuda + +# -- api -- +from stnls.utils import extract_pairs + +def allocate_vid(vid_shape,device): + vid = th.zeros(vid_shape,device=device,dtype=th.float32) + return vid + +def allocate_patches(b,nq,nhead,pt,c,ps,device): + patches = th.zeros((b,nq,nhead,pt,c,ps,ps),device=device,dtype=th.float32) + return patches + +def get_inds(inds,itype): + inds = inds.contiguous() + if itype == "int" and th.is_floating_point(inds): + return inds.round().int() + elif itype in ["float","2d","3d"] and not(th.is_floating_point(inds)): + return inds.float() + else: + return inds + +class PooledPatchSumFunction(th.autograd.Function): + + @staticmethod + def forward(ctx, vid, weights, flows, ps, stride0, + pt=1, dilation=1, reflect_bounds=True, use_adj=False, itype="float"): + """ + vid = [BatchSize,nHeads or 1,T,C,H,W] + weights = [BatchSize,nHeads,NumQueries,K] + flows = [BatchSize,nHeads or 1,NumQueries,K,3] + ps = patchsize + pt = patchsize_time (forward only) + """ + + # -- add head dim if 1 -- + vid_in_dim = vid.ndim + total_color = vid.shape[-3] + bsize,nheads = weights.shape[:2] + if vid.ndim == 5: + if (total_color % nheads) == 0: + vid = rearrange(vid,'b t (H c) h w -> b H t c h w',H=nheads) + else: + vid = rearrange(vid,'b t c h w -> b 1 t c h w') + if flows.ndim == 4: flows = flows[:,None] # add heads dim + + # -- unpack -- + device = weights.device + B,HD,T,nH,nW,K = weights.shape + wshape = weights.shape + vid = vid.contiguous() + flows = get_inds(flows,itype) + + # -- shape output -- + inF,inH,inW = vid.shape[-3:] + psHalf = (ps-1)//2+1 + outH = ps*nH + outW = ps*nW + # print("ps,nH,nW,outH,outW: ",ps,nH,nW,outH,outW) + out_shape = (B,HD,T,inF,outH,outW) + + # -- allocate -- + dtype = vid.dtype + out_vid = th.zeros(out_shape,device=device,dtype=dtype) + counts = th.zeros_like(out_vid[0,0,0,0,:,:]).type(th.int) + patch_offset = 0 if use_adj else -(ps//2) + # print(patch_offset) + + # -- view -- + Q = T*nH*nW + weights = weights.view(B,HD,Q,K) + flows = flows.view(B,HD,Q,K,3) + + # -- exec -- + fwd_fxn = stnls_cuda.pool_int_forward + # if flows.dtype == th.int: + # fwd_fxn = stnls_cuda.pool_int_forward + # else: + # # flows[...,1:] = flows[...,1:].int()+1 + # fwd_fxn = stnls_cuda.pool_bilin2d_forward + fwd_fxn(out_vid, counts, vid, weights, flows, + ps, stride0, pt, dilation, reflect_bounds, patch_offset) + eps = 1e-10 + # print(out_vid.shape,vid.shape) + # print(out_vid.sum((-2,-1))) + # # print(out_vid[0,0,0,0,:,:].sum((-2))) + # # print(out_vid[0,0,0,0,:,:].sum((-1))) + # print(out_vid[0,0,0,0,:5,:5]) + # print(th.where(out_vid==1)) + # print(out_vid[0,0,0,0,:7,:7]) + # print(out_vid[0,0,-1,0,-7:,-7:]) + # print(out_vid[0,0,0,:,6,233]) + # print(out_vid[0,0,0,:,7,229]) + + # -- normalize -- + H,W = vid.shape[-2:] + # print(counts) + # 
print(counts.sum(-1)) + # print(counts.sum(-2)) + # exit() + # print("counts [min,max]: ",counts.min().item(),counts.max().item()) + # print("[pre] out_vid [min,max]: ",out_vid.min().item(),out_vid.max().item()) + counts = counts.view((1,1,1,1,outH,outW)) + out_vid = out_vid / (counts+eps) + # print("out_vid [min,max]: ",out_vid.min().item(),out_vid.max().item()) + assert th.all(counts>1e-3) + # exit() + + # -- backward -- + ctx.save_for_backward(weights,flows,vid,counts) + ctx.vid_in_dim = vid_in_dim + ctx.itype = itype + ctx.ps,ctx.pt = ps,pt + ctx.stride0 = stride0 + ctx.vid_shape = vid.shape + ctx.wshape = wshape + ctx.dilation = dilation + ctx.use_adj = use_adj + ctx.reflect_bounds = reflect_bounds + ctx.nheads = nheads + + return out_vid + + @staticmethod + def backward(ctx, grad_out_vid): + + # -- unpack -- + weights,flows,vid,counts = ctx.saved_tensors + ps,pt = ctx.ps,ctx.pt + stride0 = ctx.stride0 + vid_shape = ctx.vid_shape + dilation = ctx.dilation + use_adj = ctx.use_adj + reflect_bounds = ctx.reflect_bounds + HD = ctx.nheads + itype = ctx.itype + patch_offset = 0 if use_adj else -(ps//2) + + # -- normalize -- + H,W = counts.shape[-2:] + grad_out_vid = grad_out_vid / counts.view(1,1,1,H,W) + + # -- alloc -- + grad_weights = th.zeros_like(weights) + grad_flows = th.zeros_like(flows) if itype == "float" else None + grad_in_vid = th.zeros_like(grad_out_vid) + + # -- info -- + # print("ps,stride0,pt,dilation,reflect_bounds,patch_offset: ", + # ps,stride0,pt,dilation,reflect_bounds,patch_offset) + + # th.cuda.synchronize() + # print(grad_out_vid[th.where(grad_out_vid>0)]) + # print(grad_out_vid.sum()) + # print(grad_weights[0,0]) + + # -- video backward -- + if itype == "int": + bwd_fxn = stnls_cuda.pool_int_backward + bwd_fxn(grad_in_vid,grad_weights, + grad_out_vid,vid,weights,flows, + ps,stride0,pt,dilation, + reflect_bounds,patch_offset) + # elif not(flows.requires_grad): + # bwd_fxn = stnls_cuda.wpsum_bilin2d_backward + # bwd_fxn(grad_in_vid,grad_weights, + # grad_out_vid,vid,weights,flows, + # ps,stride0,pt,dilation, + # reflect_bounds,patch_offset) + else: + bwd_fxn = stnls_cuda.pool_bilin2d_backward + bwd_fxn(grad_in_vid,grad_weights,grad_flows, + grad_out_vid,vid,weights,flows, + ps,stride0,pt,dilation, + reflect_bounds,patch_offset) + + # print(th.where(grad_weights[0,0].abs()>0)) + # print(grad_weights[th.where(grad_weights.abs()>0)]) + # print(grad_out_vid.sum(),grad_weights.sum()) + + # -- shaping vid -- + vid_in_dim = ctx.vid_in_dim + if vid_in_dim == 5: + grad_in_vid = rearrange(grad_in_vid,'b hd t c h w -> b t (hd c) h w') + grad_in_vid = grad_in_vid.contiguous() + + # -- shaping weight,flows -- + grad_weights = grad_weights.reshape(ctx.wshape) + if ctx.itype == "float": + grad_flows = grad_flows.reshape(ctx.wshape+(3,)) + else: + grad_flows = None + + return grad_in_vid,grad_weights,grad_flows,None,None,None,\ + None,None,None,None,None,None,None,None,None + +class PooledPatchSum(th.nn.Module): + # [video -> patches] @ flows + + def __init__(self, ps, stride0, pt=1, dilation=1, + reflect_bounds=True, use_adj=False, itype="float"): + super().__init__() + _vars = ["ps","stride0","pt","dilation","reflect_bounds","use_adj","itype"] + self._vars = _vars + for var in _vars: + setattr(self,var,eval(var)) + + def forward(self, vid, weights, flows): + inputs = [getattr(self,var) for var in self._vars] + vid_out = PooledPatchSumFunction.apply(vid,weights,flows,*inputs) + return vid_out + + def flops(self, nrefs, chnls_per_head, nheads, k): + + # -- init -- + flops = 0 + + # 
-- unpack -- + chnls = chnls_per_head + ps,pt = self.ps,self.pt + + # -- compute weighted patch sum -- + flops_per_patch = 2 * (chnls * ps * ps * pt) # multi weight & add to accumulate + flops_per_ref = flops_per_patch * k # accumulate over "k" patches + flops = flops_per_ref * nrefs * nheads# do for each reference + + return flops + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# [Direct API] stnls.agg.wpsum(...) +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + +def _apply(vid, weights, flows, ps, stride0, + pt=1, dilation=1,reflect_bounds=True, use_adj=False): + # wrap "new (2018) apply function + # https://discuss.pytorch.org #13845/17 + fxn = PooledPatchSumFunction.apply + return fxn(vid,weights,flows,ps,stride0, + pt,dilation,reflect_bounds,use_adj) + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# [Python Dict API] stnls.agg.wpsum(pydict) +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + +def extract_config(cfg,restrict=True): + pairs = {"ps":3,"stride0":1,"pt":1,"dilation":1, + "reflect_bounds":True, "use_adj":False, "itype":"float"} + return extract_pairs(cfg,pairs,restrict=restrict) + +def init(cfg): + cfg = extract_config(cfg,False) + reducer = PooledPatchSum( + cfg.ps, cfg.stride0, pt=cfg.pt, dilation=cfg.dilation, + reflect_bounds=cfg.reflect_bounds,use_adj=cfg.use_adj,itype=cfg.itype) + return reducer + + + diff --git a/lib/stnls/agg/scatter_labels.py b/lib/stnls/agg/scatter_labels.py index 402ac47..2d7b4a9 100644 --- a/lib/stnls/agg/scatter_labels.py +++ b/lib/stnls/agg/scatter_labels.py @@ -39,11 +39,11 @@ def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws): # -- number of maximum possible groups a single patch can belong to -- Wt_num = T if wt > 0 else 1 # Ws_num = ws*ws - wsNum = (ws-1)//stride0+1 + wsNum = (ws)//stride0+1 Ws_num = wsNum*wsNum if full_ws: Ws_num += 2*wsNum*(wsNum//2) + (wsNum//2)**2 S = Wt_num*Ws_num - print(S,ws,wt,stride0,stride1,full_ws) + # print(S,ws,wt,stride0,stride1,full_ws) # -- prepare -- labels = -th.ones((B,HD,Q,K),device=flows.device,dtype=th.int) @@ -55,9 +55,9 @@ def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws): # -- check -- nvalid = (names[...,0] >= 0).float().sum(2) - if full_ws: - print(int(nvalid.sum().item()),Q*K) - # assert(int(nvalid.sum().item()) == Q*K) + # if full_ws: + # print(int(nvalid.sum().item()),Q*K) + # # assert(int(nvalid.sum().item()) == Q*K) return names,labels diff --git a/lib/stnls/agg/scatter_tensor.py b/lib/stnls/agg/scatter_tensor.py index 283fa39..7b63426 100644 --- a/lib/stnls/agg/scatter_tensor.py +++ b/lib/stnls/agg/scatter_tensor.py @@ -17,7 +17,7 @@ # -- cpp cuda kernel -- import stnls_cuda -def run(tensor,flows_k,labels,stride0,stride1,H,W): +def run(tensor,flows_k,labels,stride0,stride1,H,W,invalid=th.inf): # -- unpack shapes -- B,HD,T,nH0,nW0,K = tensor.shape[:6] @@ -36,7 +36,7 @@ def run(tensor,flows_k,labels,stride0,stride1,H,W): # -- prepare -- shape = (B,HD,Q1,S,M) - scatter_tensor = -th.inf*th.ones(shape,device=labels.device,dtype=tensor.dtype) + scatter_tensor = invalid*th.ones(shape,device=labels.device,dtype=tensor.dtype) stnls_cuda.scatter_tensor_forward(scatter_tensor,tensor,labels,flows_k, stride0,stride1,H,W) diff --git a/lib/stnls/dev/__init__.py b/lib/stnls/dev/__init__.py index be1514e..0d11261 100644 --- a/lib/stnls/dev/__init__.py +++ b/lib/stnls/dev/__init__.py @@ -1,2 +1,3 @@ # --- api --- from . import search +from . 
import slic diff --git a/lib/stnls/dev/slic/__init__.py b/lib/stnls/dev/slic/__init__.py new file mode 100644 index 0000000..71d10de --- /dev/null +++ b/lib/stnls/dev/slic/__init__.py @@ -0,0 +1,262 @@ +""" + + Slic is easy with our packages + +""" + +# -- basic -- +import torch as th +import numpy as np +from einops import rearrange,repeat +from easydict import EasyDict as edict +from dev_basics.utils import vid_io + +# -- exps -- +from dev_basics.utils.misc import set_seed + +# -- optical flow -- +from dev_basics import flow + +# -- data -- +import data_hub + +# -- non-local opts -- +import stnls + +# -- benchmarking -- +from dev_basics.utils.timer import ExpTimer,TimeIt +from dev_basics.utils.gpu_mem import GpuMemer,MemIt + +# -- view segmentation -- +from torchvision.utils import draw_segmentation_masks +from skimage.segmentation import mark_boundaries + + +def load_video(cfg): + device = "cuda:0" + data,loaders = data_hub.sets.load(cfg) + indices = data_hub.filter_subseq(data[cfg.dset],cfg.vid_name,0,cfg.nframes) + vid = data[cfg.dset][indices[0]]['clean'][None,:].to(device)/255. + # F = 32 + # B,T,_,H,W = vid.shape + # vid = th.randn((B,T,F,H,W),device=device,dtype=vid.dtype) + return vid + +def append_grid(vid,M,S): + B,T,F,H,W = vid.shape + dtype,device = vid.dtype,vid.device + grid_y, grid_x = th.meshgrid(th.arange(0, H, dtype=dtype, device=device), + th.arange(0, W, dtype=dtype, device=device)) + grid = th.stack((grid_x, grid_y), -1).float() # 2, W(x), H(y) + grid = repeat(grid,'h w two -> b t two h w',b=B,t=T) + vid = th.cat([vid,M/S*grid],2) + return vid + +def slic_select(vid,ws): + + # -- config -- + ps = 1 + # ws = 3 + wt = 0 + stride0 = 8 + ws = 2*stride0-2 + # stride0,ws = 3,5 + stride1 = 1 + K0 = 1 + softmax_weight = 10. + k = -1 + full_ws = True + use_flow = False + M = 0.1 + use_rand = True + + # -- compute search window -- + B,T,F,H,W = vid.shape + search = stnls.search.NonLocalSearch(ws,wt,ps,k, + nheads=1,dist_type="l2", + stride0=stride0, + self_action="anchor_self", + full_ws=full_ws,itype="int") + + flows = flow.orun(vid,use_flow,ftype="cv2") + flows = stnls.nn.search_flow(flows.fflow,flows.bflow,wt,stride0) + flows = flows[:,None].round().int() + + vid = append_grid(vid,M,stride0) + dists,flows_k = search(vid,vid,flows) + # inds = stnls.utils.misc.flow2inds(flows_k,stride0) + + # -- scattering top-K=1 -- + K0 = 1 + gather_weights = dists + names,labels = stnls.agg.scatter_labels(flows,flows_k,ws,wt, + stride0,stride1,H,W,full_ws) + gather_labels = labels.reshape_as(gather_weights) + scatter_weights = stnls.agg.scatter_tensor(gather_weights,flows_k,labels, + stride0,stride1,H,W) + scatter_flows_k = stnls.agg.scatter_tensor(flows_k,flows_k,labels, + stride0,stride1,H,W) + scatter_labels = stnls.agg.scatter_tensor(gather_labels,flows_k,labels, + stride0,stride1,H,W) + + # -- topk -- + scatter_flows_k = -scatter_flows_k + s_weight,s_flows_k,s_labels = stnls.agg.scatter_topk(scatter_weights,scatter_flows_k, + scatter_labels,K0, + descending=False) + + # -- prepare weights and flows -- + pooled,weights,flows_k = slic_pooling(vid,s_weight,s_flows_k,s_labels, + ps,stride0,stride1,K0, + softmax_weight,"wpsum") + # print(th.cat([weights[...,None],flows_k],-1)) + + # -- refine -- + assert pooled.shape[-2:] == vid.shape[-2:],"Same Spatial Dim [H x W]" + wr,k,kr = 1,1,1. 
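+    # note (added): with wr=k=kr=1 the refinement search should only re-score the
+    # flows handed to it (a 1x1 refinement window); use_rand=True swaps both inputs
+    # for noise, which reads as a sanity check of the index bookkeeping alone.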
+ refine = stnls.search.RefineSearch(ws, wt, wr, k, kr, ps, nheads=1, + stride0=stride0, dist_type="l2", itype="int") + if use_rand: + pooled = th.rand_like(pooled) + vid = th.rand_like(vid) + dists,flows_k = refine(pooled,vid,flows_k) + # weights = th.softmax(-softmax_weight*dists,-1) + # print(vid.shape,dists.shape,flows_k.shape) + + # -- flows to mask -- + # inds = inds2labels(s_flows_k,cfg,H,W) + # print(th.cat([dists[...,None],flows_k],-1)) + inds = stnls.utils.misc.flow2inds(flows_k,stride0).long() + inds = rearrange(inds,'b hd t h w 1 tr -> (b hd) (t h w) tr') + select = inds[:,:,0]*H*W + inds[:,:,1]*W + inds[:,:,2] + print("select.shape: ",select.shape) + # print(inds.shape) + # # print(inds) + # # # print(inds.shape) + + mask = th.zeros(B,T*H*W).to(select.device) + for bi in range(B): + mask[bi,select[bi]] = 1 + # print(mask.shape,select.shape) + # mask = mask.scatter_(1,select,1) + mask = mask.reshape(B,1,H,W) + # mask = th.zeros((B,H*W)) + # mask[inds] = 1 + # print(H*W/mask.sum()) + + # exit() + + return mask + +def run_slic(vid,flows,cfg): + + # -- compute search window -- + B,T,F,H,W = vid.shape + search = stnls.search.NonLocalSearch(cfg.ws,cfg.wt,cfg.ps,cfg.k, + nheads=1,dist_type="l2", + stride0=cfg.stride0, + self_action="anchor_self", + full_ws=cfg.full_ws,itype="int") + vid = append_grid(vid,cfg.M,cfg.stride0) + dists,flows_k = search(vid,vid,flows) + # inds = stnls.utils.misc.flow2inds(flows_k,cfg.stride0) + + # -- scattering top-K=1 -- + K0 = 1 + gather_weights = dists + names,labels = stnls.agg.scatter_labels(flows,flows_k,cfg.ws,cfg.wt, + cfg.stride0,cfg.stride1,H,W,cfg.full_ws) + gather_labels = labels.reshape_as(gather_weights) + scatter_weights = stnls.agg.scatter_tensor(gather_weights,flows_k,labels, + cfg.stride0,cfg.stride1,H,W) + scatter_flows_k = stnls.agg.scatter_tensor(flows_k,flows_k,labels, + cfg.stride0,cfg.stride1,H,W) + scatter_labels = stnls.agg.scatter_tensor(gather_labels,flows_k,labels, + cfg.stride0,cfg.stride1,H,W) + + # -- topk -- + scatter_flows_k = -scatter_flows_k + s_weight,s_flows_k,s_labels = stnls.agg.scatter_topk(scatter_weights,scatter_flows_k, + scatter_labels,K0, + descending=False) + + # -- pooling -- + pooled,_,_ = slic_pooling(vid,s_weight,s_flows_k,s_labels, + cfg.ps,cfg.stride0,cfg.stride1,K0,cfg.softmax_weight) + return pooled[:,:,:3],s_flows_k + + +def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0, + softmax_weight,pool_method="pool"): + + # -- prepare weights and flows -- + B,T,F,H,W = vid.shape + HD = s_weights.shape[1] + s_weights = s_weights.reshape(B,HD,T,H,W,K0) + s_flows_k = s_flows_k.reshape(B,HD,T,H,W,K0,3) + s_labels = s_labels.reshape(B,HD,T*H*W,-1) + + # -- run scatters -- + weights = stnls.agg.scatter_tensor(s_weights,s_flows_k,s_labels, + stride1,stride0,H,W) + flows_k = stnls.agg.scatter_tensor(s_flows_k,s_flows_k,s_labels, + stride1,stride0,H,W) + + # -- reshape -- + K = weights.shape[-1] + nH = (H-1)//stride0+1 + nW = (W-1)//stride0+1 + weights = weights.reshape(B,HD,T,nH,nW,K) + flows_k = flows_k.reshape(B,HD,T,nH,nW,K,3) + + # -- renormalize weights -- + weights = th.softmax(-softmax_weight*weights,-1) + + # -- aggregate -- + ps = stride0 + if pool_method == "pool": + agg = stnls.agg.PooledPatchSum(ps,stride0,itype="int") + elif pool_method == "wpsum": + ps = stride0*2 + agg = stnls.agg.WeightedPatchSum(ps,stride0,itype="int") + else: + raise ValueError(f"Uknown pool method [{pool_method}]") + vout = agg(vid,weights,flows_k) + vout = rearrange(vout,'b hd t c h w -> b t (hd c) h w') + + 
return vout,weights,flows_k + +def inds2labels(s_flows_k,cfg,H,W): + + # -- get segmentation labels -- + nH0,nW0 = (H-1)//cfg.stride0+1,(W-1)//cfg.stride0+1 + nH,nW = (H-1)//cfg.stride1+1,(W-1)//cfg.stride1+1 + shape_str = 'b hd (t nh nw) k tr -> b hd t nh nw k tr' + s_flows_k = rearrange(s_flows_k,shape_str,nh=nH,nw=nW) + s_inds = stnls.utils.misc.flow2inds(s_flows_k,cfg.stride1) + nH0,nW0 = H//cfg.stride0,W//cfg.stride0 + s_inds = s_inds[:,0,...,0,:].contiguous() # 1 head, 1 k + stnls.utils.misc.reflect_inds(s_inds,H,W) + + # -- labels -- + seg_labels = s_inds[...,0]*nH0*nW0 + seg_labels += th.div(s_inds[...,1],cfg.stride0,rounding_mode="floor")*nW0 + seg_labels += th.div(s_inds[...,2],cfg.stride0,rounding_mode="floor") + + # -- fill invalid -- + valid = th.logical_and(seg_labels<100000,seg_labels>-100000) + S = seg_labels[th.where(valid)].max() + seg_labels[th.where(~valid)] = S+1 + + # -- view -- + # print(seg_labels.shape) + # print(seg_labels[0,0,-5:,-5:]) + + return seg_labels + +def labels2masks(labels): + S = labels.max()+1 + masks = th.zeros([S,]+list(labels.shape),dtype=th.bool).to(labels.device) + for si in range(S): + masks[si] = labels==si + return masks diff --git a/scripts/slic.py b/scripts/slic.py index 2c494cd..3952ab3 100644 --- a/scripts/slic.py +++ b/scripts/slic.py @@ -61,13 +61,13 @@ def run_exp(cfg): vid = load_video(cfg).half() vid = append_grid(vid,cfg.M,cfg.stride0) B,T,F,H,W = vid.shape - print("vid.shape: ",vid.shape) + # print("vid.shape: ",vid.shape) # -- compute flows -- flows = flow.orun(vid,cfg.flow,ftype="cv2") flows = stnls.nn.search_flow(flows.fflow,flows.bflow,cfg.wt,cfg.stride0) flows = flows[:,None].round().int() - print(flows.shape) + # print(flows.shape) # -- benchmark -- timer,memer = ExpTimer(),GpuMemer() @@ -81,7 +81,8 @@ def run_exp(cfg): for pooling_type in pooling_grid: with TimeIt(timer,pooling_type): with MemIt(memer,pooling_type): - pooled_p,seg_p = run_pooling(cfg,vid,flows,pooling_type) + pooled_p,seg_p = run_pooling(cfg,vid,flows,pooling_type, + cfg.pooling_ksize,cfg.stride0) pooled[pooling_type] = pooled_p seg[pooling_type] = seg_p @@ -95,11 +96,11 @@ def run_exp(cfg): return vid,pooled,seg -def run_pooling(cfg,vid,flows,pooling_type): - ws,stride0 = cfg.ws,cfg.stride0 - ksize = ws - ksize = stride0 - stride = 1#stride0//2 +def run_pooling(cfg,vid,flows,pooling_type,ksize,stride): + ws = cfg.ws + # ksize = ws + # ksize = stride0 + # stride = 1#stride0//2 # stride = stride0 B = vid.shape[0] @@ -118,11 +119,11 @@ def run_standard(pool_fxn,vid,ksize,stride): elif pooling_type == "slic": pool_fxn = th.nn.functional.avg_pool2d pooled,seg = run_slic(vid,flows,cfg) - pooled = run_standard(pool_fxn,pooled,ksize,stride) + # pooled = run_standard(pool_fxn,pooled,ksize,stride) elif pooling_type == "nls": pool_fxn = th.nn.functional.avg_pool2d pooled,seg = run_nls(vid,flows,cfg) - pooled = run_standard(pool_fxn,pooled,ksize,stride) + # pooled = run_standard(pool_fxn,pooled,ksize,stride) else: raise ValueError("Uknown pooling type.") return pooled,seg @@ -137,11 +138,14 @@ def run_nls(vid,flows,cfg): self_action="anchor_self", full_ws=full_ws,itype="int") dists,flows_k = search(vid,vid,flows) - weights = th.softmax(-dists,-1) + weights = th.softmax(-cfg.softmax_weight*dists,-1) + # print(weights) # -- aggregate -- - ps = int(cfg.stride0*1.75) - agg = stnls.agg.WeightedPatchSum(ps,cfg.stride0,itype="int") + ps = cfg.stride0 + # ps = int(cfg.stride0*1.75) + # agg = stnls.agg.WeightedPatchSum(ps,cfg.stride0,itype="int") + agg = 
stnls.agg.PooledPatchSum(ps,cfg.stride0,itype="int") vout = agg(vid,weights,flows_k) vout = rearrange(vout,'b hd t c h w -> b t (hd c) h w') @@ -167,7 +171,8 @@ def run_slic(vid,flows,cfg): # -- scattering top-K=1 -- K0 = 1 - gather_weights = th.softmax(-dists,-1) + # gather_weights = th.softmax(-dists,-1) + gather_weights = dists # timer,memer = ExpTimer(),GpuMemer() # with TimeIt(timer,"labels"): # with MemIt(memer,"labels"): @@ -175,7 +180,9 @@ def run_slic(vid,flows,cfg): cfg.stride0,cfg.stride1,H,W,cfg.full_ws) # print(timer,memer) # print(labels.min().item(),labels.max().item()) - print("[scattering]: ",gather_weights.shape,flows_k.shape,labels.shape) + # print("[scattering]: ",gather_weights.shape,flows_k.shape,labels.shape) + # print(gather_weights[0,0,0,0,0]) + # print(labels[0,0,0]) gather_labels = labels.reshape_as(gather_weights) scatter_weights = stnls.agg.scatter_tensor(gather_weights,flows_k,labels, cfg.stride0,cfg.stride1,H,W) @@ -183,7 +190,7 @@ def run_slic(vid,flows,cfg): cfg.stride0,cfg.stride1,H,W) scatter_labels = stnls.agg.scatter_tensor(gather_labels,flows_k,labels, cfg.stride0,cfg.stride1,H,W) - print("[a]: ",scatter_flows_k.shape,flows_k.shape,scatter_labels.shape) + # print("[a]: ",scatter_flows_k.shape,flows_k.shape,scatter_labels.shape) # -- checking in -- @@ -197,10 +204,16 @@ def run_slic(vid,flows,cfg): # print(scatter_flows_k[0,0,0,-3:,-3:]) # exit() + both = th.cat([scatter_weights[...,None],scatter_flows_k],-1) + # print(both.shape) + # print(both[0,0,0]) + # exit() + # -- topk -- scatter_flows_k = -scatter_flows_k s_weight,s_flows_k,s_labels = stnls.agg.scatter_topk(scatter_weights,scatter_flows_k, - scatter_labels,K0) + scatter_labels,K0, + descending=False) # print(s_flows_k.shape,s_labels.shape) # s_flows_k = s_flows_k.int() # print(th.any(s_weight<-1000).item()) @@ -216,17 +229,21 @@ def run_slic(vid,flows,cfg): # print(s_flows_k[0,0,:3]) # print(s_flows_k[0,0,100:103]) # print(s_flows_k[0,0,-3:]) + both = th.cat([s_weight[...,None],s_flows_k],-1) + # print(both.shape) + # print(both[0,0,:,:]) + # exit() pooled = slic_pooling(vid,s_weight,s_flows_k,s_labels, - cfg.ps,cfg.stride0,cfg.stride1,K0) + cfg.ps,cfg.stride0,cfg.stride1,K0,cfg.softmax_weight) # pooled = None - print(pooled.shape) + # print(pooled.shape) return pooled[:,:,:3],s_flows_k -def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0): +def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0,softmax_weight): # -- prepare weights and flows -- B,T,F,H,W = vid.shape @@ -236,12 +253,12 @@ def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0): s_labels = s_labels.reshape(B,HD,T*H*W,-1) # -- run scatters -- - print("pooling: ",s_weights.shape,s_flows_k.shape,s_labels.shape) + # print("pooling: ",s_weights.shape,s_flows_k.shape,s_labels.shape) weights = stnls.agg.scatter_tensor(s_weights,s_flows_k,s_labels, stride1,stride0,H,W) flows_k = stnls.agg.scatter_tensor(s_flows_k,s_flows_k,s_labels, stride1,stride0,H,W) - print(weights.shape,flows_k.shape) + # print(weights.shape,flows_k.shape) # -- reshape -- K = weights.shape[-1] @@ -252,22 +269,27 @@ def slic_pooling(vid,s_weights,s_flows_k,s_labels,ps,stride0,stride1,K0): # -- renormalize weights -- # print(weights) - weights = th.softmax(weights,-1) + weights = th.softmax(-softmax_weight*weights,-1) # print(weights) - - # # -- inspect -- - # print("scatter_weights.shape: ",weights.shape) - # args = th.where(th.isnan(weights[0,0])) - # print(args) - # exit() + # weights = weights / 
th.sum(weights,-1,keepdim=True) + # print(th.sum(weights,-1)) + # print(weights[0,0,:2,:2]) + # print(weights[0,0,-2:,-2:]) # -- aggregate -- - ps = int(stride0*1.75) + ps = stride0 # ps = stride0 - agg = stnls.agg.WeightedPatchSum(ps,stride0,itype="int") + # agg = stnls.agg.WeightedPatchSum(ps,stride0,itype="int") + # print(th.sum(weights,-1)) + # print(ps,stride0) + agg = stnls.agg.PooledPatchSum(ps,stride0,itype="int") + # vid = th.ones_like(vid) + # print("weights [min,max]: ",weights.min().item(),weights.max().item()) vout = agg(vid,weights,flows_k) vout = rearrange(vout,'b hd t c h w -> b t (hd c) h w') - # vout = None + # print("vin [min,max]: ",vid[...,:3,:,:].min().item(),vid[...,:3,:,:].max().item()) + # print("vout [min,max]: ",vout[...,:3,:,:].min().item(),vout[...,:3,:,:].max().item()) + # # vout = None # print("vout.shape,vid.shape: ",vout.shape,vid.shape) return vout @@ -314,7 +336,8 @@ def main(): cfg.seed = 123 cfg.dname = "set8" cfg.dset = "val" - cfg.isize = "540_540" + # cfg.isize = "540_540" + cfg.isize = "260_260" # cfg.isize = "256_256" # cfg.isize = "128_128" # cfg.isize = None @@ -322,19 +345,23 @@ def main(): # cfg.isize = "300_300" cfg.vid_name = "sunflower" cfg.ntype = "g" - cfg.sigma = 0.1 - cfg.nframes = 5 + cfg.sigma = 0.01 + cfg.nframes = 1 cfg.flow = False cfg.full_ws = True cfg.wt = 0 - cfg.stride0 = 8 + cfg.stride0 = 5 cfg.ws = 2*cfg.stride0-2 + # cfg.stride0 = 3 + # cfg.ws = 3 # if cfg.ws == 1: cfg.ws += 1 cfg.stride1 = 1 cfg.k = -1#cfg.ws*cfg.ws cfg.nls_k = 8 cfg.ps = 1 cfg.M = 0.1 + cfg.pooling_ksize = 1 + cfg.softmax_weight = 20. # -- run slic -- vid,pooled,segs = run_exp(cfg) @@ -362,8 +389,9 @@ def main(): vid_io.save_video(seg,"output/slic","ex") for ptype in pooled: - print(pooled[ptype].type,pooled[ptype].shape,pooled[ptype].max()) - vid_io.save_video(pooled[ptype][:,:,:3],"output/slic_pooled/",ptype) + print(pooled[ptype].type,pooled[ptype].shape,pooled[ptype][:,:,:3].max()) + vid = pooled[ptype][:,:,:3] + vid_io.save_video(vid,"output/slic_pooled/",ptype) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 8042f66..a9c811b 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,9 @@ 'lib/csrc/agg/wpsum.cpp', # Weighted Patch Sum 'lib/csrc/agg/wpsum_int_kernel.cu', 'lib/csrc/agg/wpsum_bilin2d_kernel.cu', + 'lib/csrc/agg/pool.cpp', # Pooled - Weighted Patch Sum + 'lib/csrc/agg/pool_int_kernel.cu', + #'lib/csrc/agg/pool_bilin2d_kernel.cu', # -- setup -- 'lib/csrc/pybind.cpp', ],
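
Usage note (not part of the patch): below is a minimal sketch of driving the new PooledPatchSum aggregation from Python, mirroring the calls in scripts/slic.py (stnls.agg.PooledPatchSum(ps, stride0, itype="int") followed by the 'b hd t c h w -> b t (hd c) h w' rearrange). It assumes the stnls_cuda extension has been rebuilt with the pool.cpp / pool_int_kernel.cu sources now listed in setup.py; the tensor sizes and the random weights / zero flows are illustrative stand-ins for the outputs of stnls.search.NonLocalSearch and th.softmax(-softmax_weight*dists, -1).

# Hedged sketch: shapes follow the docstring in lib/stnls/agg/pool.py.
#   vid     : (B, T, C, H, W)           -- heads are split from C inside forward()
#   weights : (B, HD, T, nH, nW, K)     -- per-query aggregation weights
#   flows   : (B, HD, T, nH, nW, K, 3)  -- integer (dt, dh, dw) offsets
import torch as th
from einops import rearrange
import stnls

B, HD, T, C, H, W = 1, 1, 2, 3, 60, 60
stride0 = 5                      # odd stride, as in the updated scripts/slic.py
ps = stride0                     # pooled output is (ps*nH) x (ps*nW) = H x W here
nH, nW, K = H // stride0, W // stride0, 4

vid = th.rand(B, T, C, H, W, device="cuda")       # kernels are CUDA-only
weights = th.softmax(th.rand(B, HD, T, nH, nW, K, device="cuda"), dim=-1)
flows = th.zeros(B, HD, T, nH, nW, K, 3, device="cuda").int()   # zero offsets

agg = stnls.agg.PooledPatchSum(ps, stride0, itype="int")
pooled = agg(vid, weights, flows)                 # (B, HD, T, C, ps*nH, ps*nW)
pooled = rearrange(pooled, 'b hd t c h w -> b t (hd c) h w')
print(pooled.shape)                               # (1, 2, 3, 60, 60)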