diff --git a/lib/csrc/nn/anchor_self.cpp b/lib/csrc/nn/anchor_self.cpp index dc4df07..dbcd2c2 100644 --- a/lib/csrc/nn/anchor_self.cpp +++ b/lib/csrc/nn/anchor_self.cpp @@ -34,6 +34,9 @@ void anchor_self_refine_forward_cuda( torch::Tensor dists, torch::Tensor inds, torch::Tensor flows, int stride0, int H, int W); +void anchor_self_paired_forward_cuda( + torch::Tensor dists, torch::Tensor inds, + torch::Tensor flows, int stride0, int H, int W); // C++ interface @@ -68,6 +71,15 @@ void anchor_self_refine_forward( anchor_self_refine_forward_cuda(dists,inds,flows,stride0,H,W); } +void anchor_self_paired_forward( + torch::Tensor dists, torch::Tensor inds, + torch::Tensor flows, int stride0, int H, int W){ + CHECK_INPUT(dists); + CHECK_INPUT(inds); + CHECK_INPUT(flows); + anchor_self_paired_forward_cuda(dists,inds,flows,stride0,H,W); +} + // python bindings void init_anchor_self(py::module &m){ m.def("anchor_self", &anchor_self_forward, @@ -76,5 +88,7 @@ void init_anchor_self(py::module &m){ "anchor_self (CUDA)"); m.def("anchor_self_refine", &anchor_self_refine_forward, "anchor_self (CUDA)"); + m.def("anchor_self_paired", &anchor_self_paired_forward, + "anchor_self (CUDA)"); } diff --git a/lib/csrc/nn/anchor_self_kernel.cu b/lib/csrc/nn/anchor_self_kernel.cu index 6984353..f6f7dad 100644 --- a/lib/csrc/nn/anchor_self_kernel.cu +++ b/lib/csrc/nn/anchor_self_kernel.cu @@ -224,14 +224,14 @@ __global__ void anchor_self_time_kernel( if (qi >= Q){ continue; } // -- unpack pixel locs -- - // get_pixel_loc(loc, qi, tmp, stride0, nW, nHW, H,W); - int tmp = qi; - iloc[0] = qi / nHW; - tmp = (tmp - iloc[0]*nHW); - int nH_index = tmp / nW; - iloc[1] = (nH_index*stride0) % H; - tmp = tmp - nH_index*nW; - iloc[2] = ((tmp % nW) * stride0) % W; + get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W); + // int tmp = qi; + // iloc[0] = qi / nHW; + // tmp = (tmp - iloc[0]*nHW); + // int nH_index = tmp / nW; + // iloc[1] = (nH_index*stride0) % H; + // tmp = tmp - nH_index*nW; + // iloc[2] = ((tmp % nW) * stride0) % W; // -- select time -- int n_hi = iloc[1]/stride0; @@ -241,7 +241,6 @@ __global__ void anchor_self_time_kernel( int t_next = iloc[0] + st_i; t_next = (t_next > t_max) ? 
t_max - st_i : t_next; - // -- get anchor index -- loc[0] = t_next - iloc[0]; if (st_i >= st_offset){ @@ -366,6 +365,175 @@ void anchor_self_time_forward_cuda( } +/********************************************************* + + + Anchor Paired Search + + +*********************************************************/ + + +template +__global__ void anchor_self_paired_kernel( + torch::PackedTensorAccessor32 dists, + torch::PackedTensorAccessor32 inds, + torch::PackedTensorAccessor32 flows, + int stride0, int H, int W, int nHW, int nW, int q_per_thread){ + + // -- starting qi for thread -- + int HD = dists.size(1); + int HD_f = flows.size(1); + int Q = dists.size(2); + int W_t = dists.size(3); + int K = dists.size(4); + int T = flows.size(2); + int bi = blockIdx.y; + int hi = blockIdx.z; + int hi_f = hi % HD_f; + int raster_idx = threadIdx.x + blockDim.x * blockIdx.x; + int qi_thread = raster_idx/W_t; + int gi = (raster_idx - qi_thread*W_t); + qi_thread = q_per_thread*qi_thread; + int self_index = 0; + bool eq_loc; + int iloc[3]; + itype loc[3]; + itype i_tmp[3]; + scalar_t d_tmp; + int qi; + scalar_t delta,dmin_curr; + int min_idx; + + // -- for each location -- + if (gi >= W_t ){ return; } + for (int qi_ix = 0; qi_ix < q_per_thread; qi_ix++){ + + // -- current query -- + qi = qi_thread + qi_ix; + if (qi >= Q){ continue; } + + // -- unpack pixel locs -- + get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W); + int n_hi = iloc[1]/stride0; + int n_wi = iloc[2]/stride0; + + // -- get anchor index -- + auto flows_t = flows[bi][hi_f][gi]; + loc[0] = iloc[1] + flows_t[1][n_hi][n_wi]; + loc[1] = iloc[2] + flows_t[0][n_hi][n_wi]; + loc[0] = bounds(loc[0],H)-iloc[1]; + loc[1] = bounds(loc[1],W)-iloc[2]; + + // -- search for matching index -- + min_idx = 0; + dmin_curr = 10000; + for (self_index = 0; self_index < K; self_index++){ + + delta = 0; + eq_loc = true; + for (int ix=0; ix<2; ix++){ + if (is_same_v){ + eq_loc = eq_loc && (inds[bi][hi][qi][gi][self_index][ix] == loc[ix]); + }else{ + delta += fabs(inds[bi][hi][qi][gi][self_index][ix] - loc[ix]); + } + } + eq_loc = eq_loc && (delta < 1e-8); + + if (is_same_v){ + if (eq_loc){ min_idx = self_index; break; } + }else{ + if (delta < 1e-8){ min_idx = self_index; break; }// break if equal + else if (delta < dmin_curr){ // update min otherwise + min_idx = self_index; + dmin_curr = delta; + } + } + + } + assert(min_idx<<>>( + dists.packed_accessor32(), + inds.packed_accessor32(), + flows.packed_accessor32(), + stride0, H, W, nHW, nW, q_per_thread); + })); + }else if (itype == dtype){ + AT_DISPATCH_FLOATING_TYPES(dists.type(), "anchor_self_paired_kernel", ([&] { + anchor_self_paired_kernel<<>>( + dists.packed_accessor32(), + inds.packed_accessor32(), + flows.packed_accessor32(), + stride0, H, W, nHW, nW, q_per_thread); + })); + + }else{ + std::cout << "Must have inds type be int or match dists.\n" << std::endl; + assert(1==0); + } + +} + + /********************************************************* @@ -397,6 +565,7 @@ __global__ void anchor_self_refine_kernel( qi_thread = q_per_thread*qi_thread; int self_index = 0; bool eq_loc; + int iloc[3]; itype loc[3]; itype i_tmp[3]; scalar_t d_tmp; @@ -412,10 +581,20 @@ __global__ void anchor_self_refine_kernel( qi = qi_thread + qi_ix; if (qi >= Q){ continue; } - // -- unpack pixel locs -- + // // -- unpack pixel locs -- + // loc[0] = round(flows[bi][hi_f][qi][gi][0]); + // loc[1] = flows[bi][hi_f][qi][gi][1]; + // loc[2] = flows[bi][hi_f][qi][gi][2]; + get_pixel_loc(iloc, qi, stride0, nW, nHW, H, W); + // int n_hi = 
iloc[1]/stride0; + // int n_wi = iloc[2]/stride0; + + // -- get anchor index -- loc[0] = round(flows[bi][hi_f][qi][gi][0]); - loc[1] = flows[bi][hi_f][qi][gi][1]; - loc[2] = flows[bi][hi_f][qi][gi][2]; + loc[1] = iloc[1] + flows[bi][hi_f][qi][gi][1]; + loc[2] = iloc[2] + flows[bi][hi_f][qi][gi][2]; + loc[1] = bounds(loc[1],H)-iloc[1]; + loc[2] = bounds(loc[2],W)-iloc[2]; // -- search for matching index -- min_idx = 0; diff --git a/lib/csrc/search/paired_search_kernel.cu b/lib/csrc/search/paired_search_kernel.cu index b593578..ad6f1ba 100644 --- a/lib/csrc/search/paired_search_kernel.cu +++ b/lib/csrc/search/paired_search_kernel.cu @@ -20,7 +20,7 @@ template __global__ void paired_search_int_forward_kernel( const torch::PackedTensorAccessor32 frame0, const torch::PackedTensorAccessor32 frame1, - const torch::PackedTensorAccessor32 flow, + const torch::PackedTensorAccessor32 flow, torch::PackedTensorAccessor32 dists, torch::PackedTensorAccessor32 inds, int ws, int ps, int stride0, int stride1, int dilation, @@ -193,7 +193,7 @@ void paired_search_int_forward_cuda( paired_search_int_forward_kernel<<>>( frame0.packed_accessor32(), frame1.packed_accessor32(), - flow.packed_accessor32(), + flow.packed_accessor32(), dists.packed_accessor32(), inds.packed_accessor32(), ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, @@ -204,7 +204,7 @@ void paired_search_int_forward_cuda( paired_search_int_forward_kernel<<>>( frame0.packed_accessor32(), frame1.packed_accessor32(), - flow.packed_accessor32(), + flow.packed_accessor32(), dists.packed_accessor32(), inds.packed_accessor32(), ws, ps, stride0, stride1, dilation, reflect_bounds, full_ws, diff --git a/lib/stnls/nn/__init__.py b/lib/stnls/nn/__init__.py index 82cb7d7..37c022a 100644 --- a/lib/stnls/nn/__init__.py +++ b/lib/stnls/nn/__init__.py @@ -15,6 +15,7 @@ anchor_self = anchor_self_f.run anchor_self_time = anchor_self_f.run_time anchor_self_refine = anchor_self_f.run_refine +anchor_self_paired = anchor_self_f.run_paired non_local_inds = non_local_inds_f.run accumulate_flow = accumulate_flow_f.run extract_search_from_accumulated = accumulate_flow_f.extract_search_from_accumulated diff --git a/lib/stnls/nn/anchor_self.py b/lib/stnls/nn/anchor_self.py index daa9c53..7ae0f7a 100644 --- a/lib/stnls/nn/anchor_self.py +++ b/lib/stnls/nn/anchor_self.py @@ -49,11 +49,19 @@ def run(dists,inds,stride0,H,W,qstart=0): def run_refine(dists,inds,flows,stride0,H,W): # -- view -- - B,HD,T,nH,nW,Ks,ws,ws = dists.shape - dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws) - inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3) HD_f = flows.shape[1] - flows = flows.view(B,HD_f,T*nH*nW,Ks,3) + if dists.ndim == 8: + B,HD,T,nH,nW,Ks,ws,ws = dists.shape + dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws) + inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3) + flows = flows.view(B,HD_f,T*nH*nW,Ks,3) + assert inds.shape[-1] == 3,"Index must be size 3." + # elif dists.ndim == 6: + # B,HD,T,nH,nW,Ks,ws,ws = dists.shape + # dists = dists.view(B,HD,T*nH*nW,Ks,ws*ws) + # inds = inds.view(B,HD,T*nH*nW,Ks,ws*ws,3) + # flows = flows.view(B,HD_f,T*nH*nW,Ks,3) + # print("dists.shape,inds.shape,flows.shape: ",dists.shape,inds.shape,flows.shape) # -- run -- stnls_cuda.anchor_self_refine(dists,inds,flows,stride0,H,W) @@ -65,7 +73,21 @@ def run_time(dists,inds,flows,wt,stride0,H,W): d2or3 = inds.shape[-1] dists = dists.view(B,HD,Q,W_t,ws*ws) inds = inds.view(B,HD,Q,W_t,ws*ws,d2or3) + assert d2or3 == 3,"Index must be size 3." 
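+    # NOTE: the last dim of inds holds (dt,dh,dw) offsets from each query
+    # location, so the time-anchored kernel needs all three components (hence
+    # the assert above); run_paired below anchors one frame pair at a time
+    # and therefore only carries (dh,dw).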
# -- run -- stnls_cuda.anchor_self_time(dists,inds,flows,wt,stride0,H,W) +def run_paired(dists,inds,flows,stride0,H,W): + + # -- view -- + B,HD,Q,G,ws,ws = dists.shape + d2or3 = inds.shape[-1] + dists = dists.view(B,HD,Q,G,ws*ws) + inds = inds.view(B,HD,Q,G,ws*ws,d2or3) + assert d2or3 == 2,"Index must be size 2." + + # -- run -- + stnls_cuda.anchor_self_paired(dists,inds,flows,stride0,H,W) + + diff --git a/lib/stnls/search/batching_utils.py b/lib/stnls/search/batching_utils.py deleted file mode 100644 index d2a6fa1..0000000 --- a/lib/stnls/search/batching_utils.py +++ /dev/null @@ -1,48 +0,0 @@ - -import torch as th - -def run_batched(run_fxn,batchsize,vid_idx,stride0_idx,ws_idx,wt_idx,*args): - dists,inds = [],[] - ntotal,nbatches,batchsize = batching_info(args[vid_idx],args[stride0_idx], - args[ws_idx],args[wt_idx], - batchsize) - for nbatch in range(nbatches): - qshift = nbatch*batchsize - nqueries = min(ntotal-qshift,batchsize) - # print(nbatches,batch,qshift,nqueries,ntotal) - assert nqueries > 0 - dists_b,inds_b = run_fxn(qshift,nqueries,*args) - dists.append(dists_b) - inds.append(inds_b) - dists = th.cat(dists,2) - inds = th.cat(inds,2) - return dists,inds - -def batching_info(vid,stride0,ws,wt,batchsize): - - # -- compute num refs -- - B,HD,T,C,H,W = vid.shape - nH = (H-1)//stride0+1 - nW = (W-1)//stride0+1 - ntotal = T * nH * nW - - # -- recompute batch size w.r.t max size -- - batchsize = get_max_batchsize(batchsize,ntotal,ws,wt) - - nbatches = (ntotal-1)//batchsize+1 - return ntotal,nbatches,batchsize - -def get_max_batchsize(batchsize,nrefs,ws,wt): - # ntotal_locs = nrefs * nsearch - # ntotal_ints = ntotal_locs*ntotal_search - st = 2 * wt + 1 - nsearch = ws * ws * st - # nmax = 2**31-1 - nmax = 2**22 - max_nrefs = int(nmax / (nsearch*3)) - # print(batchsize,max_nrefs,nrefs,ws,wt,nsearch) - # print(batchsize,max_nrefs,nrefs) - if batchsize <= 0: - batchsize = min(max_nrefs,nrefs) - batchsize = min(batchsize,min(max_nrefs,nrefs)) - return batchsize diff --git a/lib/stnls/search/n3mm_utils.py b/lib/stnls/search/impl/n3mm_utils.py similarity index 98% rename from lib/stnls/search/n3mm_utils.py rename to lib/stnls/search/impl/n3mm_utils.py index c65cfa1..0998575 100644 --- a/lib/stnls/search/n3mm_utils.py +++ b/lib/stnls/search/impl/n3mm_utils.py @@ -3,7 +3,7 @@ import numpy as np from einops import rearrange import stnls_cuda -from .shared import run_unfold +from ..shared import run_unfold # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # diff --git a/lib/stnls/search/impl/non_local_search.py b/lib/stnls/search/impl/non_local_search.py new file mode 100644 index 0000000..b950ec4 --- /dev/null +++ b/lib/stnls/search/impl/non_local_search.py @@ -0,0 +1,173 @@ +# -- python -- +import torch as th +import numpy as np +from einops import rearrange + +# -- cpp cuda kernel -- +import stnls_cuda + +# -- package -- +import stnls + +# -- api -- +from stnls.utils import extract_pairs + +# -- forward utils -- +from stnls.search.utils import allocate_pair,dist_type_select,allocate_vid + +# -- backward utils -- +from stnls.search.utils import get_inds,allocate_grad_flows +from stnls.search.shared import normz_bwd + + +def forward(vid0, vid1, flows, + ws, wt, ps, k, stride0, stride1, + dist_type, dilation, pt, + topk_mode, self_action, + reflect_bounds, full_ws, use_adj, itype): + + # -- unpack -- + # itype = "int" + device = vid0.device + B,HD,T,C,H,W = vid0.shape + patch_offset = 0 if use_adj else -(ps//2) + # print(ps,k,dist_type,topk_mode,self_action,patch_offset) + + # -- derived shapes -- + nH0 = 
(H-1)//stride0+1 + nW0 = (W-1)//stride0+1 + Q = T*nH0*nW0 + # print(vid0.shape,nH0,nW0,Q) + + # -- settings from distance type -- + dist_type_i,descending,idist_val = dist_type_select(dist_type) + + # -- allocate results -- + W_t = min(2*wt+1,T) + base_shape = (B,HD,Q,W_t,ws,ws) + dists,inds = allocate_pair(base_shape,device,vid0.dtype,idist_val,itype) + + # -- check flows -- + assert flows.shape[3] in [W_t-1,W_t] + + # -- forward -- + if itype == "int": + if flows.dtype != th.int: + flows = flows.round().int() + else: + flows = flows.int() + inds = inds.int() + stride1 = max(1,int(stride1)) + fwd_fxn = stnls_cuda.non_local_search_int_forward + else: + fwd_fxn = stnls_cuda.non_local_search_bilin2d_forward + stride1 = float(stride1) + fwd_fxn(vid0, vid1, flows, dists, inds, + ps, k, stride0, stride1, dilation, pt, + reflect_bounds, full_ws, patch_offset, dist_type_i) + + # -- anchor -- + assert self_action in [None,"anchor","anchor_each","remove","remove_ref_frame"] + anchor_self = False if self_action is None else "anchor" in self_action + if self_action == "anchor": + stnls.nn.anchor_self(dists,inds,stride0,H,W) + elif self_action == "anchor_each": + stnls.nn.anchor_self_time(dists,inds,flows,wt,stride0,H,W) + elif self_action == "remove": + raise NotImplementedError("Not implemented self_action [remove].") + elif self_action == "remove_ref_frame": + assert wt > 0,"Cannot remove ref frame if not searching across time." + dists = dists[...,1:,:,:].contiguous() + inds = inds[...,1:,:,:,:].contiguous() + elif self_action is None: + pass + else: + raise ValueError(f"Uknown option for self_action [{self_action}]") + + # -- topk -- + if topk_mode == "all": + dim = 3 + dists=dists.view(B,HD,Q,W_t*ws*ws) + inds=inds.view(B,HD,Q,W_t*ws*ws,3) + dists,inds = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, + descending=descending) + elif topk_mode == "each": + dists = rearrange(dists,'... wh ww -> ... (wh ww)') + inds = rearrange(inds,'... wh ww d2or3 -> ... 
(wh ww) d2or3') + dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self) + else: + raise ValueError(f"Unknown topk_mode [{topk_mode}]") + + # -- reshape -- + dists=dists.view(B,HD,T,nH0,nW0,-1) + inds=inds.view(B,HD,T,nH0,nW0,-1,3) + + return dists,inds + +def backward(ctx, grad_dists, grad_inds): + + # -- populate names -- + dists,inds,vid0,vid1,flows = ctx.saved_tensors + itype_bwd = ctx.itype + + # -- allocate grads -- + grad_vid0 = allocate_vid(ctx.vid_shape,grad_dists.device) + grad_vid1 = allocate_vid(ctx.vid_shape,grad_dists.device) + grad_flows = allocate_grad_flows(itype_bwd,flows.shape,flows.device) + + # -- restrict to k_agg; the number of neighbors used which will prop gradient -- + if ctx.k_agg > 0: + grad_dists = grad_dists[...,:ctx.k_agg].contiguous() + grad_inds = grad_inds[...,:ctx.k_agg].contiguous() + dists = dists[...,:ctx.k_agg].contiguous() + inds = inds[...,:ctx.k_agg,:] + dists = dists.contiguous() + inds = inds.contiguous() + + # -- ensure contiguous & type -- + inds = get_inds(inds,ctx.itype) + patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) + reflect_bounds = ctx.reflect_bounds + + # -- backward pass with increasing complexity -- + if ctx.itype == "int": + bwd_fxn = stnls_cuda.non_local_search_int_vid_backward + bwd_fxn(grad_vid0,grad_vid1, + vid0,vid1,grad_dists,inds, + ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + reflect_bounds,patch_offset,ctx.dist_type_i) + elif not(flows.requires_grad): + bwd_fxn = stnls_cuda.non_local_search_bilin2d_vid_backward + bwd_fxn(grad_vid0,grad_vid1,vid0,vid1, + grad_dists,inds, + ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + reflect_bounds,patch_offset,ctx.dist_type_i) + else: + bwd_fxn = stnls_cuda.non_local_search_bilin2d_vidflows_backward + bwd_fxn(grad_vid0,grad_vid1,grad_flows, + vid0,vid1,flows, + grad_dists,grad_inds,dists,inds, + ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + reflect_bounds,patch_offset,ctx.dist_type_i) + + # -- finalize shape -- + grad_vid0 = rearrange(grad_vid0,'B H t c h w -> B t (H c) h w') + grad_vid1 = rearrange(grad_vid1,'B H t c h w -> B t (H c) h w') + + # -- normz -- + if ctx.normalize_bwd: + normz_bwd(ctx,grad_vid0,grad_vid1) + + # -- no grad if ints -- + if itype_bwd == "int" or not(flows.requires_grad): + grad_flows = None + if ctx.flow_ndim == 6 and flows.requires_grad: + grad_flows = grad_flows[:,0].contiguous() + # print(grad_flows.shape) + # print(th.where(flows[0,0]!=0)) + # print(th.where(grad_flows[0,0]!=0)) + # print("-"*20) + # print(th.all(grad_flows==0).item()) + + return grad_vid0,grad_vid1,grad_flows + diff --git a/lib/stnls/search/impl/paired_search.py b/lib/stnls/search/impl/paired_search.py new file mode 100644 index 0000000..6c1a1ad --- /dev/null +++ b/lib/stnls/search/impl/paired_search.py @@ -0,0 +1,214 @@ + +# -- python -- +import torch as th +import numpy as np +from einops import rearrange + +# -- cpp cuda kernel -- +import stnls_cuda + +# -- package -- +import stnls + +# -- local -- +from ..utils import shape_frames,allocate_pair_2d,dist_type_select,allocate_vid +from ..utils import get_ctx_shell,ensure_flow_shape +from ..utils import ensure_paired_flow_dim as ensure_flow_dim +from ..shared import manage_self,reflect_bounds_warning +from ..utils import paired_vids as _paired_vids +# from ..paired_utils import paired_vids as _paired_vids +from ..utils import shape_vids,allocate_pair,dist_type_select,allocate_vid +from ..utils import get_inds,allocate_grad_flows +from ..shared import manage_self + +def forward(frame0, frame1, flow, + ws, ps, k, dist_type, + 
stride0, stride1, dilation, pt, + self_action, reflect_bounds, + full_ws, use_adj, itype): + + # -- unpack -- + device = frame0.device + B,HD_fr,C,H,W = frame0.shape + HD_flow = flow.shape[1] + # print(frame0.shape,flow.shape) + assert flow.ndim == 5 + HD = max(HD_flow,HD_fr) + patch_offset = 0 if use_adj else -(ps//2) + + # -- derived shapes -- + nH0 = (H-1)//stride0+1 + nW0 = (W-1)//stride0+1 + Q = nH0*nW0 + + # -- settings from distance type -- + dist_type_i,descending,idist_val = dist_type_select(dist_type) + + # -- allocate results -- + base_shape = (B,HD,Q,ws,ws) + dists,inds = allocate_pair_2d(base_shape,device,frame0.dtype,idist_val,itype) + # print("inds.shape: ",inds.shape) + + # -- forward -- + if itype == "int": + flow = flow.round().int() + inds = inds.int() + stride1 = max(1,int(stride1)) + fwd_fxn = stnls_cuda.paired_search_int_forward + else: + fwd_fxn = stnls_cuda.paired_search_bilin2d_forward + stride1 = float(stride1) + fwd_fxn(frame0, frame1, flow, dists, inds, + ps, k, stride0, stride1, dilation, + reflect_bounds, full_ws, patch_offset, dist_type_i) + + # -- anchor -- + assert self_action in [None,"anchor","anchor_each"] + anchor_self = False if self_action is None else "anchor" in self_action + if self_action is None: pass + elif "anchor" in self_action: + dists,inds = dists[...,None,:,:],inds[...,None,:,:,:] + flow = flow[:,:,None] + stnls.nn.anchor_self_paired(dists,inds,flow,stride0,H,W) + else: + raise ValueError(f"Uknown option for self_action [{self_action}]") + + + # -- topk -- + if k > 0: + dim = 3 + dists=dists.view(B,HD,Q,ws*ws) + inds=inds.view(B,HD,Q,ws*ws,2) + dists,inds = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, + descending=descending) + + # -- reshape -- + dists=dists.reshape(B,HD,1,nH0,nW0,-1) + inds=inds.reshape(B,HD,1,nH0,nW0,-1,2) + + return dists,inds + + + +def backward(ctx, grad_dists, grad_inds): + + # -- populate names -- + inds,frame0,frame1,flow = ctx.saved_tensors + itype_bwd = ctx.itype + inds = get_inds(inds,itype_bwd) + grad_flow = allocate_grad_flows(itype_bwd,flow.shape,flow.device) + + # -- allocate grads -- + grad_frame0 = allocate_vid(ctx.vid_shape,grad_dists.device) + grad_frame1 = allocate_vid(ctx.vid_shape,grad_dists.device) + # return grad_vid0,grad_vid1,grad_flow + + # -- restrict to k_agg -- + if ctx.k_agg > 0: + grad_dists = grad_dists[...,:ctx.k_agg] + inds = inds[...,:ctx.k_agg,:] + + # -- ensure contiguous -- + grad_dists = grad_dists.contiguous() + inds = inds.contiguous() + + # -- view -- + B,HD,T,nH,nW,K,_ = inds.shape + inds = inds.view(B,HD,T*nH*nW,K,2) + grad_inds = grad_inds.view(B,HD,T*nH*nW,K,2) + grad_dists = grad_dists.view(B,HD,T*nH*nW,K) + patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) + + # -- allow for repeated exec -- + grad_inds = grad_inds.contiguous() + if itype_bwd == "int": + bwd_fxn = stnls_cuda.paired_search_int_backward + inds = inds.view(B,HD,T*nH*nW,K,2) + bwd_fxn(grad_frame0,grad_frame1, + frame0,frame1,grad_dists,inds, + ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds, + patch_offset,ctx.dist_type_i) + else: + bwd_fxn = stnls_cuda.paired_search_bilin2d_backward + bwd_fxn(grad_frame0,grad_frame1,grad_flow, + frame0,frame1,flow,grad_dists,grad_inds,inds, + ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds, + patch_offset,ctx.dist_type_i) + + # -- finalize shape -- + if ctx.in_ndim == 4: + grad_frame0 = rearrange(grad_frame0,'B H c h w -> B (H c) h w') + grad_frame1 = rearrange(grad_frame1,'B H c h w -> B (H c) h w') + + # -- normz -- + if ctx.normalize_bwd: + 
normz_bwd(ctx,grad_frame0,grad_frame1) + + # -- no grad if ints -- + if itype_bwd == "int": + grad_flow = None + + return grad_frame0,grad_frame1,grad_flow + + +def paired_refine_backward(ctx, grad_dists, grad_inds): + + # -- populate names -- + inds,frame0,frame1,flow = ctx.saved_tensors + itype_bwd = ctx.itype + inds = get_inds(inds,itype_bwd) + grad_flow = allocate_grad_flows(itype_bwd,flow.shape,flow.device) + + # -- allocate grads -- + grad_frame0 = allocate_vid(ctx.vid_shape,grad_dists.device) + grad_frame1 = allocate_vid(ctx.vid_shape,grad_dists.device) + # return grad_vid0,grad_vid1,grad_flow + + # -- restrict to k_agg -- + if ctx.k_agg > 0: + grad_dists = grad_dists[...,:ctx.k_agg] + inds = inds[...,:ctx.k_agg,:] + + # -- ensure contiguous -- + grad_dists = grad_dists.contiguous() + inds = inds.contiguous() + + # -- view -- + B,HD,T,nH,nW,K,_ = inds.shape + inds = inds.view(B,HD,T*nH*nW,K,2) + grad_inds = grad_inds.view(B,HD,T*nH*nW,K,2) + grad_dists = grad_dists.view(B,HD,T*nH*nW,K) + patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) + + # -- allow for repeated exec -- + grad_inds = grad_inds.contiguous() + if itype_bwd == "int": + bwd_fxn = stnls_cuda.paired_refine_int_backward + inds = inds.view(B,HD,T*nH*nW,K,2) + bwd_fxn(grad_frame0,grad_frame1, + frame0,frame1,grad_dists,inds, + ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds, + patch_offset,ctx.dist_type_i) + else: + bwd_fxn = stnls_cuda.paired_refine_bilin2d_backward + bwd_fxn(grad_frame0,grad_frame1,grad_flow, + frame0,frame1,flow,grad_dists,grad_inds,inds, + ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds, + patch_offset,ctx.dist_type_i) + + # -- finalize shape -- + if ctx.in_ndim == 4: + grad_frame0 = rearrange(grad_frame0,'B H c h w -> B (H c) h w') + grad_frame1 = rearrange(grad_frame1,'B H c h w -> B (H c) h w') + + # -- normz -- + if ctx.normalize_bwd: + normz_bwd(ctx,grad_frame0,grad_frame1) + + # -- no grad if ints -- + if itype_bwd == "int": + grad_flow = None + + return grad_frame0,grad_frame1,grad_flow + + diff --git a/lib/stnls/search/impl/refinement.py b/lib/stnls/search/impl/refinement.py new file mode 100644 index 0000000..d5f5399 --- /dev/null +++ b/lib/stnls/search/impl/refinement.py @@ -0,0 +1,170 @@ + + +# -- python -- +import torch as th +import numpy as np +from einops import rearrange + +# -- cpp cuda kernel -- +import stnls_cuda + +# -- package -- +import stnls + +# -- local -- +from stnls.search.utils import allocate_pair,dist_type_select,allocate_vid +from stnls.search.utils import get_inds,allocate_grad_flows +from stnls.search.shared import normz_bwd + +def forward(vid0, vid1, flows, + ws, wr, k, kr, ps, stride0, stride1, dilation, pt, + dist_type, restricted_radius, + reflect_bounds, full_ws, + topk_mode, self_action, patch_offset, itype_fwd): + + # -- fix negative Q -- + # if Q > 0: + # flows = flows[:,:,qshift:qshift+Q].contiguous() + B,HD,T,nH,nW,Ks,_ = flows.shape + Q = T*nH*nW + + # -- settings from distance type -- + dist_type_i,descending,idist_val = dist_type_select(dist_type) + + # -- allocate results -- + device = flows.device + B,HD,T,nH,nW,Ks = flows.shape[:-1] + base_shape = (B,HD,T,nH,nW,Ks,wr,wr) + # print(base_shape,flows.shape) + dists,inds = allocate_pair(base_shape,device,vid0.dtype,idist_val,itype_fwd) + + # -- allow for int fwd when actually float -- + if itype_fwd == "int": + inds = inds.int() + if flows.dtype == th.float: + flows = flows.round().int() + kselect = th.zeros(0,device=flows.device) + reflect = th.zeros(0,device=flows.device) + else: + kselect = 
th.ones_like(dists).int() + reflect = th.zeros_like(flows[...,:2]).bool() + + # -- run -- + if itype_fwd == "int": + stride1 = int(max(1,int(stride1))) + fwd_fxn = stnls_cuda.refinement_int_forward + fwd_fxn(vid0, vid1, flows, dists, inds, + ws, ps, stride0, stride1, dilation, pt, + restricted_radius, reflect_bounds, full_ws, + patch_offset, dist_type_i) + else: + stride1 = float(stride1) + fwd_fxn = stnls_cuda.refinement_bilin2d_forward + fwd_fxn(vid0, vid1, flows, dists, inds, + kselect, reflect, + ws, ps, stride0, stride1, dilation, pt, + restricted_radius, reflect_bounds, full_ws, + patch_offset, dist_type_i) + + # -- manage self dists -- + if not(self_action is None) and "anchor" in self_action: + H,W = vid0.shape[-2:] + stnls.nn.anchor_self_refine(dists,inds,flows,stride0,H,W) + else: + assert self_action == None + + # -- topk -- + assert self_action in [None,"anchor","anchor_each"] + anchor_self = False if self_action is None else "anchor" in self_action + if topk_mode == "all": + dim = 3 + dists=dists.view(B,HD,Q,Ks*wr*wr) + inds=inds.view(B,HD,Q,Ks*wr*wr,3) + dists,inds,order = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, + descending=descending,unique=False, + return_order=True) + if kselect.ndim > 1: + # print("kselect.shape: ",kselect.shape,order.shape) + kselect = kselect.view(B,HD,Q,Ks*wr*wr) if not(kselect is None) else kselect + # print("kselect.shape: ",kselect.shape,order.shape) + kselect = stnls.nn.topk_f.apply_topk(kselect,order,dim) + elif topk_mode == "each": + # print(dists.shape,kselect.shape) + dists = rearrange(dists,'... wh ww -> ... (wh ww)') + inds = rearrange(inds,'... wh ww d2or3 -> ... (wh ww) d2or3') + dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self) + if kselect.ndim > 1: + kselect = rearrange(kselect,'... wh ww -> ... 
(wh ww)') + kselect = kselect[...,:k] # all same across dim + else: + raise ValueError(f"Unknown topk_mode [{topk_mode}]") + + + # -- reshape for output -- + dists=dists.view(B,HD,T,nH,nW,-1) + inds=inds.view(B,HD,T,nH,nW,-1,3) + kselect = kselect.view(B,HD,T,nH,nW,-1) if not(kselect is None) else kselect + # print("kselect.shape,reflect.shape: ",kselect.shape,reflect.shape) + # print(flows.shape,inds.shape,kselect.shape) + # print(th.cat([flows[0,0,...,[0]],inds[0,0,...,[0]],kselect[0,0,...,None]],-1)) + + return dists,inds,kselect,reflect + +def backward(ctx, grad_dists, grad_inds): + + # -- populate names -- + inds,vid0,vid1,kselect,reflect = ctx.saved_tensors + itype_bwd = ctx.itype_bwd + device = grad_dists.device + + # -- allocate grads -- + grad_vid0 = allocate_vid(ctx.vid_shape,device) + grad_vid1 = allocate_vid(ctx.vid_shape,device) + grad_flows = allocate_grad_flows(itype_bwd,ctx.flows_shape,device) + + # -- restrict to k_agg -- + if ctx.k_agg > 0: + grad_dists = grad_dists[...,:ctx.k_agg] + inds = inds[...,:ctx.k_agg,:] + + # -- ensure contiguous -- + grad_dists = grad_dists.contiguous() + inds = get_inds(inds,itype_bwd) + patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) + + # -- backward pass with increasing complexity -- + # print(inds[...,1:].min().item(),inds[...,1:].max().item()) + # print("gvid0: ",grad_vid0.shape) + # print(kselect.min().item(),kselect.max().item()) + if itype_bwd == "int": + bwd_fxn = stnls_cuda.non_local_search_int_vid_backward + bwd_fxn(grad_vid0,grad_vid1,vid0,vid1,grad_dists,inds, + ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + ctx.reflect_bounds,patch_offset,ctx.dist_type_i) + elif not(ctx.flows_requires_grad): + bwd_fxn = stnls_cuda.non_local_search_bilin2d_vid_backward + bwd_fxn(grad_vid0,grad_vid1,vid0,vid1,grad_dists,inds, + ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + ctx.reflect_bounds,patch_offset,ctx.dist_type_i) + else: + bwd_fxn = stnls_cuda.refinement_bilin2d_vidflows_backward + bwd_fxn(grad_vid0,grad_vid1,grad_flows, + vid0,vid1,grad_dists,grad_inds,inds, + kselect,reflect, + ctx.ws,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, + ctx.reflect_bounds,patch_offset,ctx.dist_type_i) + th.cuda.synchronize() + + # -- finalize shape -- + grad_vid0 = rearrange(grad_vid0,'B H t c h w -> B t (H c) h w') + grad_vid1 = rearrange(grad_vid1,'B H t c h w -> B t (H c) h w') + + # -- normz -- + if ctx.normalize_bwd: + normz_bwd(ctx,grad_vid0,grad_vid1) + + # -- no grad if ints -- + if itype_bwd == "int": grad_flows = None + + return grad_vid0,grad_vid1,grad_flows + diff --git a/lib/stnls/search/n3mm_search.py b/lib/stnls/search/n3mm_search.py index 85e1d3b..3c15b02 100644 --- a/lib/stnls/search/n3mm_search.py +++ b/lib/stnls/search/n3mm_search.py @@ -17,10 +17,7 @@ from .utils import shape_vids,allocate_inds,dist_type_select,allocate_vid from .utils import descending_menu from .shared import manage_self,run_fold -# from .nls_bwd_impl import nls_backward -# from .batching_utils import run_batched,batching_info -# from .n3mm_utils import IndexedMatmul1Efficient -from .n3mm_utils import matmult_fwd,matmult_bwd,raster_indices,vid2patches +from .impl.n3mm_utils import matmult_fwd,matmult_bwd,raster_indices,vid2patches # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # diff --git a/lib/stnls/search/nls_bwd_impl.py b/lib/stnls/search/nls_bwd_impl.py deleted file mode 100644 index 529c159..0000000 --- a/lib/stnls/search/nls_bwd_impl.py +++ /dev/null @@ -1,83 +0,0 @@ - -# -- python -- -import torch as th -import numpy as np -from einops import rearrange - -# -- cpp cuda kernel -- -import stnls_cuda 
- -# -- package -- -import stnls - -# -- local -- -from .utils import shape_vids,allocate_pair,dist_type_select,allocate_vid -from .utils import get_inds,allocate_grad_flows -from .shared import manage_self,normz_bwd - -def nls_backward(ctx, grad_dists, grad_inds): - - # -- populate names -- - dists,inds,vid0,vid1,flows = ctx.saved_tensors - itype_bwd = ctx.itype - - # -- allocate grads -- - grad_vid0 = allocate_vid(ctx.vid_shape,grad_dists.device) - grad_vid1 = allocate_vid(ctx.vid_shape,grad_dists.device) - grad_flows = allocate_grad_flows(itype_bwd,flows.shape,flows.device) - - # -- restrict to k_agg; the number of neighbors used which will prop gradient -- - if ctx.k_agg > 0: - grad_dists = grad_dists[...,:ctx.k_agg].contiguous() - grad_inds = grad_inds[...,:ctx.k_agg].contiguous() - dists = dists[...,:ctx.k_agg].contiguous() - inds = inds[...,:ctx.k_agg,:] - dists = dists.contiguous() - inds = inds.contiguous() - - # -- ensure contiguous & type -- - inds = get_inds(inds,ctx.itype) - patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) - reflect_bounds = ctx.reflect_bounds - - # -- backward pass with increasing complexity -- - if ctx.itype == "int": - bwd_fxn = stnls_cuda.non_local_search_int_vid_backward - bwd_fxn(grad_vid0,grad_vid1, - vid0,vid1,grad_dists,inds, - ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - reflect_bounds,patch_offset,ctx.dist_type_i) - elif not(flows.requires_grad): - bwd_fxn = stnls_cuda.non_local_search_bilin2d_vid_backward - bwd_fxn(grad_vid0,grad_vid1,vid0,vid1, - grad_dists,inds, - ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - reflect_bounds,patch_offset,ctx.dist_type_i) - else: - bwd_fxn = stnls_cuda.non_local_search_bilin2d_vidflows_backward - bwd_fxn(grad_vid0,grad_vid1,grad_flows, - vid0,vid1,flows, - grad_dists,grad_inds,dists,inds, - ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - reflect_bounds,patch_offset,ctx.dist_type_i) - - # -- finalize shape -- - grad_vid0 = rearrange(grad_vid0,'B H t c h w -> B t (H c) h w') - grad_vid1 = rearrange(grad_vid1,'B H t c h w -> B t (H c) h w') - - # -- normz -- - if ctx.normalize_bwd: - normz_bwd(ctx,grad_vid0,grad_vid1) - - # -- no grad if ints -- - if itype_bwd == "int" or not(flows.requires_grad): - grad_flows = None - if ctx.flow_ndim == 6 and flows.requires_grad: - grad_flows = grad_flows[:,0].contiguous() - # print(grad_flows.shape) - # print(th.where(flows[0,0]!=0)) - # print(th.where(grad_flows[0,0]!=0)) - # print("-"*20) - # print(th.all(grad_flows==0).item()) - - return grad_vid0,grad_vid1,grad_flows diff --git a/lib/stnls/search/non_local_search.py b/lib/stnls/search/non_local_search.py index a254042..9c5387a 100644 --- a/lib/stnls/search/non_local_search.py +++ b/lib/stnls/search/non_local_search.py @@ -14,100 +14,12 @@ from stnls.utils import extract_pairs # -- local -- -from .utils import shape_vids,allocate_pair,dist_type_select,allocate_vid -from .utils import get_ctx_shell,ensure_flow_shape,shape_flows -from .shared import manage_self,reflect_bounds_warning -from .nls_bwd_impl import nls_backward +from stnls.search.utils import shape_vids,dist_type_select +from stnls.search.utils import get_ctx_shell,shape_flows +from stnls.search.shared import reflect_bounds_warning -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- -# -# Forward Logic -# -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- - -def nls_forward(vid0, vid1, flows, - ws, wt, ps, k, stride0, stride1, - dist_type, dilation, pt, - topk_mode, self_action, - reflect_bounds, full_ws, use_adj, itype): - - # -- unpack -- - # itype = "int" - device = vid0.device - B,HD,T,C,H,W = vid0.shape - 
patch_offset = 0 if use_adj else -(ps//2) - # print(ps,k,dist_type,topk_mode,self_action,patch_offset) - - # -- derived shapes -- - nH0 = (H-1)//stride0+1 - nW0 = (W-1)//stride0+1 - Q = T*nH0*nW0 - # print(vid0.shape,nH0,nW0,Q) - - # -- settings from distance type -- - dist_type_i,descending,idist_val = dist_type_select(dist_type) - - # -- allocate results -- - W_t = min(2*wt+1,T) - base_shape = (B,HD,Q,W_t,ws,ws) - dists,inds = allocate_pair(base_shape,device,vid0.dtype,idist_val,itype) - - # -- check flows -- - assert flows.shape[3] in [W_t-1,W_t] - - # -- forward -- - if itype == "int": - if flows.dtype != th.int: - flows = flows.round().int() - else: - flows = flows.int() - inds = inds.int() - stride1 = max(1,int(stride1)) - fwd_fxn = stnls_cuda.non_local_search_int_forward - else: - fwd_fxn = stnls_cuda.non_local_search_bilin2d_forward - stride1 = float(stride1) - fwd_fxn(vid0, vid1, flows, dists, inds, - ps, k, stride0, stride1, dilation, pt, - reflect_bounds, full_ws, patch_offset, dist_type_i) - - # -- anchor -- - assert self_action in [None,"anchor","anchor_each","remove","remove_ref_frame"] - anchor_self = False if self_action is None else "anchor" in self_action - if self_action == "anchor": - stnls.nn.anchor_self(dists,inds,stride0,H,W) - elif self_action == "anchor_each": - stnls.nn.anchor_self_time(dists,inds,flows,wt,stride0,H,W) - elif self_action == "remove": - raise NotImplementedError("Not implemented self_action [remove].") - elif self_action == "remove_ref_frame": - assert wt > 0,"Cannot remove ref frame if not searching across time." - dists = dists[...,1:,:,:].contiguous() - inds = inds[...,1:,:,:,:].contiguous() - elif self_action is None: - pass - else: - raise ValueError(f"Uknown option for self_action [{self_action}]") - - # -- topk -- - if topk_mode == "all": - dim = 3 - dists=dists.view(B,HD,Q,W_t*ws*ws) - inds=inds.view(B,HD,Q,W_t*ws*ws,3) - dists,inds = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, - descending=descending) - elif topk_mode == "each": - dists = rearrange(dists,'... wh ww -> ... (wh ww)') - inds = rearrange(inds,'... wh ww d2or3 -> ... 
(wh ww) d2or3') - dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self) - else: - raise ValueError(f"Unknown topk_mode [{topk_mode}]") - - # -- reshape -- - dists=dists.view(B,HD,T,nH0,nW0,-1) - inds=inds.view(B,HD,T,nH0,nW0,-1,3) - - return dists,inds +# -- implementation -- +from stnls.search.impl.non_local_search import forward,backward # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # @@ -153,11 +65,11 @@ def forward(ctx, vid0, vid1, flows, assert (fH == nH) and (fW == nW) # -- run [optionally batched] forward function -- - dists,inds = nls_forward(vid0, vid1, flows, - ws, wt, ps, k, stride0, stride1, - dist_type, dilation, pt, - topk_mode, self_action, - reflect_bounds, full_ws, use_adj, itype) + dists,inds = forward(vid0, vid1, flows, + ws, wt, ps, k, stride0, stride1, + dist_type, dilation, pt, + topk_mode, self_action, + reflect_bounds, full_ws, use_adj, itype) # -- setup ctx -- dist_type_i = dist_type_select(dist_type)[0] @@ -182,12 +94,7 @@ def forward(ctx, vid0, vid1, flows, @staticmethod def backward(ctx, grad_dists, grad_inds): - # # -- reshape -- - # dists=dists.view(B,HD,T,nH0,nW0,-1) - # inds=inds.view(B,HD,T,nH0,nW0,-1,3) - - grad0,grad1,gfflow = nls_backward(ctx, grad_dists, grad_inds) - + grad0,grad1,gfflow = backward(ctx, grad_dists, grad_inds) return grad0,grad1,gfflow,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None,None diff --git a/lib/stnls/search/paired_refine.py b/lib/stnls/search/paired_refine.py index 8ef002e..b3b7a16 100644 --- a/lib/stnls/search/paired_refine.py +++ b/lib/stnls/search/paired_refine.py @@ -15,33 +15,22 @@ # -- local -- from .utils import shape_frames,allocate_pair_2d,dist_type_select,allocate_vid -from .utils import get_ctx_shell,ensure_flow_shape -from .shared import manage_self,reflect_bounds_warning +from .utils import get_ctx_shell,ensure_flow_shape,ensure_paired_flow_dim +from .shared import reflect_bounds_warning from .paired_bwd_impl import paired_refine_backward -from .batching_utils import run_batched,batching_info -from .paired_utils import paired_vids as _paired_vids - +from ..utils import paired_vids as _paired_vids +# from .paired_utils import paired_vids as _paired_vids # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Forward Logic # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- - -def paired_forward(batchsize,*args): - qshift,nqueries = 0,-1 - return paired_fwd_main(qshift,nqueries,*args) - -def ensure_flow_dim(flow): - if flow.ndim == 4: - flow = flow[:,None] # add nheads - return flow - def paired_refine_forward(frame0, frame1, flow, - wr, ws, ps, k, dist_type, + ws, wr, k, kr, ps, nheads, dist_type, stride0, stride1, dilation, pt, - self_action, reflect_bounds, - full_ws, use_adj, topk_mode, itype): + self_action, reflect_bounds, full_ws, + use_adj, topk_mode, itype): # -- unpack -- device = frame0.device @@ -133,20 +122,20 @@ def paired_refine_forward(frame0, frame1, flow, class PairedRefineFunction(th.autograd.Function): - @staticmethod def forward(ctx, frame0, frame1, flow, - wr, ws, ps, k, nheads=1, + ws, wr, k, kr, ps, nheads=1, dist_type="prod", stride0=4, stride1=1, dilation=1, pt=1, reflect_bounds=True, full_ws=True, self_action=None, use_adj=False, normalize_bwd=False, k_agg=-1, topk_mode="each", itype="float"): """ + Run the non-local search frame0 = [B,T,C,H,W] or [B,HD,T,C,H,W] - ws = search Window Spatial (ws) + """ # -- reshape with heads -- @@ -154,16 +143,18 @@ def forward(ctx, frame0, frame1, flow, 
device = frame0.device ctx.in_ndim = frame0.ndim frame0,frame1 = shape_frames(nheads,[frame0,frame1]) - # print("frame0.shape: ",frame0.shape) - flow = ensure_flow_dim(flow) - # flow = ensure_flow_shape(flow) + flow = ensure_paired_flow_dim(flow) B,HD,F,H,W = frame0.shape flow = flow.contiguous() reflect_bounds_warning(reflect_bounds) + # -- filter only to kr -- + flows = filter_k(flows,kr) + flows = flows.contiguous() + # -- run [optionally batched] forward function -- dists,inds = paired_refine_forward(frame0, frame1, flow, - ws, ps, k, dist_type, + ws, wr, k, kr, ps, nheads, dist_type, stride0, stride1, dilation, pt, self_action, reflect_bounds, full_ws, use_adj, topk_mode, itype) @@ -172,8 +163,7 @@ def forward(ctx, frame0, frame1, flow, dist_type_i = dist_type_select(dist_type)[0] flow = get_ctx_shell(flow,itype=="int") ctx.save_for_backward(inds,frame0,frame1,flow) - if itype == "int": - ctx.mark_non_differentiable(inds) + if itype == "int": ctx.mark_non_differentiable(inds) ctx.vid_shape = frame0.shape ctx_vars = {"stride0":stride0,"stride1":stride1, "wr":wr,"ps":ps,"pt":pt,"ws":ws,"dil":dilation, @@ -203,7 +193,7 @@ def backward(ctx, grad_dists, grad_inds): class PairedRefine(th.nn.Module): - def __init__(self, wr, ws, ps, k, nheads=1, + def __init__(self, ws, wr, k, kr, ps, nheads=1, dist_type="l2", stride0=4, stride1=1, dilation=1, pt=1, reflect_bounds=True, full_ws=True, self_action=None, use_adj=False, @@ -211,10 +201,11 @@ def __init__(self, wr, ws, ps, k, nheads=1, super().__init__() # -- core search params -- - self.wr = wr self.ws = ws - self.ps = ps + self.wr = wr self.k = k + self.kr = kr + self.ps = ps self.nheads = nheads self.dist_type = dist_type self.stride0 = stride0 @@ -243,14 +234,17 @@ def paired_vids(self, vid0, vid1, flows, wt, skip_self=False): def forward(self, frame0, frame1, flow): assert self.ws > 0,"Must have nonzero spatial search window" return PairedRefineFunction.apply(frame0,frame1,flow, - self.wr,self.ws,self.ps,self.k, - self.nheads,self.dist_type,self.stride0, - self.stride1,self.dilation,self.pt, + self.ws, self.wr, self.k, self.kr, + self.ps, self.nheads, self.dist_type, + self.stride0,self.stride1, + self.dilation,self.pt, self.reflect_bounds,self.full_ws, self.self_action,self.use_adj, self.normalize_bwd,self.k_agg, self.topk_mode,self.itype) + + def flops(self,T,F,H,W): return 0 diff --git a/lib/stnls/search/paired_search.py b/lib/stnls/search/paired_search.py index 5b014f7..85f3126 100644 --- a/lib/stnls/search/paired_search.py +++ b/lib/stnls/search/paired_search.py @@ -16,98 +16,12 @@ # -- local -- from .utils import shape_frames,allocate_pair_2d,dist_type_select,allocate_vid from .utils import get_ctx_shell,ensure_flow_shape +from .utils import ensure_paired_flow_dim as ensure_flow_dim from .shared import manage_self,reflect_bounds_warning -from .paired_bwd_impl import paired_backward -from .batching_utils import run_batched,batching_info -from .paired_utils import paired_vids as _paired_vids +from .utils import paired_vids as _paired_vids - -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- -# -# Forward Logic -# -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- - -def paired_forward(batchsize,*args): - qshift,nqueries = 0,-1 - return paired_fwd_main(qshift,nqueries,*args) - -def ensure_flow_dim(flow): - if flow.ndim == 4: - flow = flow[:,None] # add nheads - return flow - -def paired_forward(frame0, frame1, flow, - ws, ps, k, dist_type, - stride0, stride1, dilation, pt, - self_action, reflect_bounds, - full_ws, use_adj, itype): - - # -- unpack -- - device = 
frame0.device - B,HD_fr,C,H,W = frame0.shape - HD_flow = flow.shape[1] - # print(frame0.shape,flow.shape) - assert flow.ndim == 5 - HD = max(HD_flow,HD_fr) - patch_offset = 0 if use_adj else -(ps//2) - - # -- derived shapes -- - nH0 = (H-1)//stride0+1 - nW0 = (W-1)//stride0+1 - Q = nH0*nW0 - - # -- settings from distance type -- - dist_type_i,descending,idist_val = dist_type_select(dist_type) - - # -- allocate results -- - base_shape = (B,HD,Q,ws,ws) - dists,inds = allocate_pair_2d(base_shape,device,frame0.dtype,idist_val,itype) - # print("inds.shape: ",inds.shape) - - # -- forward -- - if itype == "int": - flow = flow.round() - inds = inds.int() - stride1 = max(1,int(stride1)) - fwd_fxn = stnls_cuda.paired_search_int_forward - else: - fwd_fxn = stnls_cuda.paired_search_bilin2d_forward - stride1 = float(stride1) - # print(frame0.shape,flow.shape,dists.shape,inds.shape) - fwd_fxn(frame0, frame1, flow, dists, inds, - ps, k, stride0, stride1, dilation, - reflect_bounds, full_ws, patch_offset, dist_type_i) - - # print(frame0.shape,frame1.shape,flow.shape,inds.shape) - # -- compress search region -- - # dists=dists.view(B,HD,Q,-1) - # inds=inds.view(B,HD,Q,-1,2) - # # th.cuda.synchronize() - - # -- anchor -- - assert self_action in [None,"anchor","anchor_each"] - anchor_self = False if self_action is None else "anchor" in self_action - if self_action is None: pass - elif "anchor" in self_action: - stnls.nn.anchor_self(dists[...,None,:,:],inds[...,None,:,:,:],stride0,H,W) - else: - raise ValueError(f"Uknown option for self_action [{self_action}]") - - - # -- topk -- - if k > 0: - dim = 3 - dists=dists.view(B,HD,Q,ws*ws) - inds=inds.view(B,HD,Q,ws*ws,2) - dists,inds = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, - descending=descending) - - # -- reshape -- - dists=dists.reshape(B,HD,1,nH0,nW0,-1) - inds=inds.reshape(B,HD,1,nH0,nW0,-1,2) - - return dists,inds +# -- implementation -- +from stnls.search.impl.paired_search import forward,backward # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # @@ -146,11 +60,11 @@ def forward(ctx, frame0, frame1, flow, reflect_bounds_warning(reflect_bounds) # -- run [optionally batched] forward function -- - dists,inds = paired_forward(frame0, frame1, flow, - ws, ps, k, dist_type, - stride0, stride1, dilation, pt, - self_action, reflect_bounds, full_ws, - use_adj, itype) + dists,inds = forward(frame0, frame1, flow, + ws, ps, k, dist_type, + stride0, stride1, dilation, pt, + self_action, reflect_bounds, full_ws, + use_adj, itype) # -- setup ctx -- dist_type_i = dist_type_select(dist_type)[0] @@ -173,7 +87,7 @@ def forward(ctx, frame0, frame1, flow, @staticmethod def backward(ctx, grad_dists, grad_inds): - grad0,grad1,gflow = paired_backward(ctx, grad_dists, grad_inds) + grad0,grad1,gflow = backward(ctx, grad_dists, grad_inds) return grad0,grad1,gflow,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None,None diff --git a/lib/stnls/search/ref_bwd_impl.py b/lib/stnls/search/ref_bwd_impl.py deleted file mode 100644 index 22c04f5..0000000 --- a/lib/stnls/search/ref_bwd_impl.py +++ /dev/null @@ -1,75 +0,0 @@ - -# -- python -- -import torch as th -import numpy as np -from einops import rearrange - -# -- cpp cuda kernel -- -import stnls_cuda - -# -- package -- -import stnls - -# -- local -- -from .utils import shape_vids,allocate_pair,dist_type_select,allocate_vid -from .utils import get_inds,allocate_grad_flows -from .shared import manage_self,normz_bwd - -def 
ref_backward(ctx, grad_dists, grad_inds): - - # -- populate names -- - inds,vid0,vid1,kselect,reflect = ctx.saved_tensors - itype_bwd = ctx.itype_bwd - device = grad_dists.device - - # -- allocate grads -- - grad_vid0 = allocate_vid(ctx.vid_shape,device) - grad_vid1 = allocate_vid(ctx.vid_shape,device) - grad_flows = allocate_grad_flows(itype_bwd,ctx.flows_shape,device) - - # -- restrict to k_agg -- - if ctx.k_agg > 0: - grad_dists = grad_dists[...,:ctx.k_agg] - inds = inds[...,:ctx.k_agg,:] - - # -- ensure contiguous -- - grad_dists = grad_dists.contiguous() - inds = get_inds(inds,itype_bwd) - patch_offset = 0 if ctx.use_adj else -(ctx.ps//2) - - # -- backward pass with increasing complexity -- - # print(inds[...,1:].min().item(),inds[...,1:].max().item()) - # print("gvid0: ",grad_vid0.shape) - # print(kselect.min().item(),kselect.max().item()) - if itype_bwd == "int": - bwd_fxn = stnls_cuda.non_local_search_int_vid_backward - bwd_fxn(grad_vid0,grad_vid1,vid0,vid1,grad_dists,inds, - ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - ctx.reflect_bounds,patch_offset,ctx.dist_type_i) - elif not(ctx.flows_requires_grad): - bwd_fxn = stnls_cuda.non_local_search_bilin2d_vid_backward - bwd_fxn(grad_vid0,grad_vid1,vid0,vid1,grad_dists,inds, - ctx.wt,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - ctx.reflect_bounds,patch_offset,ctx.dist_type_i) - else: - bwd_fxn = stnls_cuda.refinement_bilin2d_vidflows_backward - bwd_fxn(grad_vid0,grad_vid1,grad_flows, - vid0,vid1,grad_dists,grad_inds,inds, - kselect,reflect, - ctx.ws,ctx.ps,ctx.pt,ctx.stride0,ctx.dil, - ctx.reflect_bounds,patch_offset,ctx.dist_type_i) - th.cuda.synchronize() - - # -- finalize shape -- - grad_vid0 = rearrange(grad_vid0,'B H t c h w -> B t (H c) h w') - grad_vid1 = rearrange(grad_vid1,'B H t c h w -> B t (H c) h w') - - # -- normz -- - if ctx.normalize_bwd: - normz_bwd(ctx,grad_vid0,grad_vid1) - - # -- no grad if ints -- - if itype_bwd == "int": grad_flows = None - - return grad_vid0,grad_vid1,grad_flows - diff --git a/lib/stnls/search/refinement.py b/lib/stnls/search/refinement.py index bff7b82..32134f7 100644 --- a/lib/stnls/search/refinement.py +++ b/lib/stnls/search/refinement.py @@ -13,146 +13,12 @@ from stnls.utils import extract_pairs # -- local -- -from .utils import filter_k -from .utils import shape_vids,dist_type_select -from .utils import allocate_pair,allocate_vid -from .utils import get_ctx_qinds -from .shared import manage_self,reflect_bounds_warning -# from .nls_bwd_impl import nls_backward -from .ref_bwd_impl import ref_backward +from stnls.search.utils import filter_k,shape_vids,dist_type_select +from stnls.search.utils import shape_refinement_flows as shape_flows +from stnls.search.shared import reflect_bounds_warning -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- -# -# Forward Logic -# -# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- - -def refine_forward(vid0, vid1, flows, - ws, wr, k, kr, ps, stride0, stride1, dilation, pt, - dist_type, restricted_radius, - reflect_bounds, full_ws, - topk_mode, self_action, patch_offset, itype_fwd): - - # -- fix negative Q -- - # if Q > 0: - # flows = flows[:,:,qshift:qshift+Q].contiguous() - B,HD,T,nH,nW,Ks,_ = flows.shape - Q = T*nH*nW - - # -- settings from distance type -- - dist_type_i,descending,idist_val = dist_type_select(dist_type) - - # -- allocate results -- - device = flows.device - B,HD,T,nH,nW,Ks = flows.shape[:-1] - base_shape = (B,HD,T,nH,nW,Ks,wr,wr) - # print(base_shape,flows.shape) - dists,inds = allocate_pair(base_shape,device,vid0.dtype,idist_val,itype_fwd) - - # -- allow for int fwd when actually float 
-- - if itype_fwd == "int": - inds = inds.int() - if flows.dtype == th.float: - flows = flows.round().int() - kselect = th.zeros(0,device=flows.device) - reflect = th.zeros(0,device=flows.device) - else: - kselect = th.ones_like(dists).int() - reflect = th.zeros_like(flows[...,:2]).bool() - - # -- run -- - if itype_fwd == "int": - stride1 = int(max(1,int(stride1))) - fwd_fxn = stnls_cuda.refinement_int_forward - fwd_fxn(vid0, vid1, flows, dists, inds, - ws, ps, stride0, stride1, dilation, pt, - restricted_radius, reflect_bounds, full_ws, - patch_offset, dist_type_i) - else: - stride1 = float(stride1) - fwd_fxn = stnls_cuda.refinement_bilin2d_forward - fwd_fxn(vid0, vid1, flows, dists, inds, - kselect, reflect, - ws, ps, stride0, stride1, dilation, pt, - restricted_radius, reflect_bounds, full_ws, - patch_offset, dist_type_i) - - # print(inds[0,0,0,0,0,:2]) - # -- no negative -- - # if th.any(flows[0]<0): - # print(flows[0]) - # print(inds[0]) - - # -- manage self dists -- - # # H,W = vid0.shape[-2:] - # anchor_self = self_action == "anchor" - # remove_self = self_action == "remove" - # assert anchor_self is False - # assert remove_self is False - # return_order = not(kselect is None) - # dists,inds,kselect = manage_self_ksel(dists,inds,kselect,self_action,wr) - # # kselect = kselect[...,1:] if remove_self else kselect - # # dists.shape = (B,H,Q,Ks,wr*wr) - if not(self_action is None) and "anchor" in self_action: - H,W = vid0.shape[-2:] - stnls.nn.anchor_self_refine(dists,inds,flows,stride0,H,W) - else: - assert self_action == None - - # -- topk -- - assert self_action in [None,"anchor","anchor_each"] - anchor_self = False if self_action is None else "anchor" in self_action - if topk_mode == "all": - dim = 3 - dists=dists.view(B,HD,Q,Ks*wr*wr) - inds=inds.view(B,HD,Q,Ks*wr*wr,3) - dists,inds,order = stnls.nn.topk(dists,inds,k,dim=dim,anchor=anchor_self, - descending=descending,unique=False, - return_order=True) - if kselect.ndim > 1: - # print("kselect.shape: ",kselect.shape,order.shape) - kselect = kselect.view(B,HD,Q,Ks*wr*wr) if not(kselect is None) else kselect - # print("kselect.shape: ",kselect.shape,order.shape) - kselect = stnls.nn.topk_f.apply_topk(kselect,order,dim) - elif topk_mode == "each": - # print(dists.shape,kselect.shape) - dists = rearrange(dists,'... wh ww -> ... (wh ww)') - inds = rearrange(inds,'... wh ww d2or3 -> ... (wh ww) d2or3') - dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self) - if kselect.ndim > 1: - kselect = rearrange(kselect,'... wh ww -> ... 
(wh ww)') - kselect = kselect[...,:k] # all same across dim - else: - raise ValueError(f"Unknown topk_mode [{topk_mode}]") - - - # -- reshape for output -- - dists=dists.view(B,HD,T,nH,nW,-1) - inds=inds.view(B,HD,T,nH,nW,-1,3) - kselect = kselect.view(B,HD,T,nH,nW,-1) if not(kselect is None) else kselect - # print("kselect.shape,reflect.shape: ",kselect.shape,reflect.shape) - # print(flows.shape,inds.shape,kselect.shape) - # print(th.cat([flows[0,0,...,[0]],inds[0,0,...,[0]],kselect[0,0,...,None]],-1)) - - return dists,inds,kselect,reflect - -def shape_flows(nheads,flows,B,nH,nW): - # print(flows.shape) - if flows.ndim == 4: - B,HD,Q,tr = flows.shape - flows=rearrange(flows,'b hd (t nh nw) thr -> b hd t nh nw thr',nh=nH,nw=nW) - # elif flows.ndim == 3: - # BHD,Q,tr = flows.shape - # shape_str = '(b hd) (t nh nw) k thr -> b hd t nh nw k thr' - # flows=rearrange(flows,,b=B,nh=nH,nw=nW) - elif flows.ndim == 5: - B,HD,Q,K,tr = flows.shape - shape_str = 'b hd (t nh nw) k thr -> b hd t nh nw k thr' - flows=rearrange(flows,shape_str,b=B,nh=nH,nw=nW) - elif flows.ndim == 6: - flows=rearrange(flows,'(b hd) t nh nw thr -> (b hd) t nh nw thr',b=B) - assert flows.ndim == 7 - return flows +# -- implementation -- +from stnls.search.impl.refinement import forward,backward # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # @@ -201,11 +67,11 @@ def forward(ctx, vid0, vid1, flows, # flows_t[...,0] = flows_t[...,0].round() # -- run fwd pass -- - dists,inds,kselect,reflect = refine_forward(vid0, vid1, flows, - ws, wr, k, kr, ps, stride0, stride1, - dilation, pt, dist_type, restricted_radius, - reflect_bounds, full_ws, topk_mode, - self_action, patch_offset, itype) + dists,inds,kselect,reflect = forward(vid0, vid1, flows, + ws, wr, k, kr, ps, stride0, stride1, + dilation, pt, dist_type, restricted_radius, + reflect_bounds, full_ws, topk_mode, + self_action, patch_offset, itype) # -- reshape -- dists=dists.view(B,HD,T,nH,nW,-1) @@ -231,8 +97,7 @@ def forward(ctx, vid0, vid1, flows, @staticmethod def backward(ctx, grad_dists, grad_inds): - # print(grad_inds.abs().mean()) - grad0,grad1,grad_flows = ref_backward(ctx, grad_dists, grad_inds) + grad0,grad1,grad_flows = backward(ctx, grad_dists, grad_inds) return grad0,grad1,grad_flows,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None,\ None,None,None,None,None,None,None,None,None,None,None,None,None diff --git a/lib/stnls/search/utils.py b/lib/stnls/search/utils.py index ae278fc..75866bf 100644 --- a/lib/stnls/search/utils.py +++ b/lib/stnls/search/utils.py @@ -123,6 +123,12 @@ def filter_k(inds,kr,k=None): return inds[...,:Ks,:].contiguous() +def ensure_paired_flow_dim(flow): + if flow.ndim == 4: + flow = flow[:,None] # add nheads + return flow + + def ensure_flow_shape(flow): if flow.ndim == 5: B,T,_,H,W = flow.shape @@ -147,6 +153,25 @@ def shape_flows(nheads,flows): msg = f"Input flows are wrong dimension. 
Must be 6 or 7 but is [{ndim}]" raise ValueError(msg) +def shape_refinement_flows(nheads,flows,B,nH,nW): + # print(flows.shape) + if flows.ndim == 4: + B,HD,Q,tr = flows.shape + flows=rearrange(flows,'b hd (t nh nw) thr -> b hd t nh nw thr',nh=nH,nw=nW) + # elif flows.ndim == 3: + # BHD,Q,tr = flows.shape + # shape_str = '(b hd) (t nh nw) k thr -> b hd t nh nw k thr' + # flows=rearrange(flows,,b=B,nh=nH,nw=nW) + elif flows.ndim == 5: + B,HD,Q,K,tr = flows.shape + shape_str = 'b hd (t nh nw) k thr -> b hd t nh nw k thr' + flows=rearrange(flows,shape_str,b=B,nh=nH,nw=nW) + elif flows.ndim == 6: + flows=rearrange(flows,'(b hd) t nh nw thr -> (b hd) t nh nw thr',b=B) + assert flows.ndim == 7 + return flows + + def shape_vids(nheads,vids): _vids = [] for vid in vids: @@ -273,3 +298,129 @@ def wrap(vid0,vid1,fflow,bflow,inds,afflow,abflow): def wrap(vid0,vid1,fflow,bflow,inds,afflow,abflow): return search(vid0,vid1,fflow,bflow) return wrap + +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- +# +# Paired Utils +# +# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + +def get_time_window_inds(ti,wt,T): + swap = False + t_inc = 0 + prev_t = ti + t_shift = min(0,ti-wt) + max(0,ti + wt - (T-1)) + t_max = min(T-1,ti + wt - t_shift); + # print(t_shift,t_max) + tj = ti + inds = [] + for _tj in range(2*wt+1): + # -- update search frame -- + prev_t = tj + tj = prev_t + t_inc + swap = tj > t_max + t_inc = 1 if (t_inc == 0) else t_inc + t_inc = -1 if swap else t_inc + tj = ti-1 if swap else tj + prev_t = ti if swap else prev_t + # print(ti,tj,t_inc,swap) + inds.append(tj) + return inds + +def get_flows(flows): + assert flows.ndim in [6,7] + if flows.ndim == 6: + flows = flows[:,None] + return flows + +def paired_vids(forward, vid0, vid1, flows, wt, skip_self=False): + dists,inds = [],[] + T = vid0.shape[1] + flows = get_flows(flows) + zflow = th.zeros_like(flows[:,:,0,0]) + for ti in range(T): + t_grid = get_time_window_inds(ti,wt,T) + dists_i,inds_i = [],[] + for _tj in range(2*wt+1): + + # -- update search frame -- + tj = t_grid[_tj] + if (ti == tj) and skip_self: continue + frame0 = vid0[:,ti] + frame1 = vid1[:,tj] + if _tj > 0: flow = flows[:,:,ti,_tj-1] + else: flow = zflow + flow = flow.float() + dists_ij,inds_ij = forward(frame0,frame1,flow) + inds_t = (tj-ti)*th.ones_like(inds_ij[...,[0]]) + inds_ij = th.cat([inds_t,inds_ij],-1) + dists_i.append(dists_ij) + inds_i.append(inds_ij) + # -- stack across K -- + dists_i = th.cat(dists_i,-1) + inds_i = th.cat(inds_i,-2) + dists.append(dists_i) + inds.append(inds_i) + # -- stack across time -- + dists = th.cat(dists,-4) + inds = th.cat(inds,-5) + # print("inds.shape: ",inds.shape) + return dists,inds + +def paired_vids_old(forward, vid0, vid1, acc_flows, wt, skip_self=False): + dists,inds = [],[] + T = vid0.shape[1] + zflow = th.zeros_like(acc_flows.fflow[:,0,0]) + for ti in range(T): + # if ti != 1: continue + + swap = False + t_inc = 0 + prev_t = ti + t_shift = min(0,ti-wt) + max(0,ti + wt - (T-1)) + t_max = min(T-1,ti + wt - t_shift); + # print(t_shift,t_max) + tj = ti + + dists_i,inds_i = [],[] + for _tj in range(2*wt+1): + + # -- update search frame -- + prev_t = tj + tj = prev_t + t_inc + swap = tj > t_max + t_inc = 1 if (t_inc == 0) else t_inc + t_inc = -1 if swap else t_inc + tj = ti-1 if swap else tj + prev_t = ti if swap else prev_t + # print(ti,tj,t_inc,swap) + + frame0 = vid0[:,ti] + frame1 = vid1[:,tj] + if (ti == tj) and skip_self: continue + if ti == tj: + flow = zflow + elif ti < tj: + # print("fwd: ",ti,tj,tj-ti-1) + # flow = 
+def paired_vids_old(forward, vid0, vid1, acc_flows, wt, skip_self=False):
+    dists,inds = [],[]
+    T = vid0.shape[1]
+    zflow = th.zeros_like(acc_flows.fflow[:,0,0])
+    for ti in range(T):
+
+        swap = False
+        t_inc = 0
+        prev_t = ti
+        t_shift = min(0,ti-wt) + max(0,ti + wt - (T-1))
+        t_max = min(T-1,ti + wt - t_shift)
+        tj = ti
+
+        dists_i,inds_i = [],[]
+        for _tj in range(2*wt+1):
+
+            # -- update search frame --
+            prev_t = tj
+            tj = prev_t + t_inc
+            swap = tj > t_max
+            t_inc = 1 if (t_inc == 0) else t_inc
+            t_inc = -1 if swap else t_inc
+            tj = ti-1 if swap else tj
+            prev_t = ti if swap else prev_t
+
+            frame0 = vid0[:,ti]
+            frame1 = vid1[:,tj]
+            if (ti == tj) and skip_self: continue
+            if ti == tj:
+                flow = zflow
+            elif ti < tj:
+                flow = acc_flows.fflow[:,ti,tj-ti-1]
+            elif ti > tj:
+                flow = acc_flows.bflow[:,ti,ti-tj-1]
+            flow = flow.float()
+            dists_ij,inds_ij = forward(frame0,frame1,flow)
+            inds_t = tj*th.ones_like(inds_ij[...,[0]])
+            inds_ij = th.cat([inds_t,inds_ij],-1)
+            dists_i.append(dists_ij)
+            inds_i.append(inds_ij)
+        dists_i = th.cat(dists_i,-1)
+        inds_i = th.cat(inds_i,-2)
+        dists.append(dists_i)
+        inds.append(inds_i)
+    dists = th.cat(dists,-2)
+    inds = th.cat(inds,-3)
+    return dists,inds
+
diff --git a/lib/stnls/testing/__init__.py b/lib/stnls/testing/__init__.py
index 3054639..fb82312 100644
--- a/lib/stnls/testing/__init__.py
+++ b/lib/stnls/testing/__init__.py
@@ -4,3 +4,47 @@ from . import gradcheck

 find_duplicate_inds = find_duplicate_inds_f.run
+
+
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+#
+#               Misc Functions
+#
+# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
+
+import torch as th
+
+
+def check_shuffled_inds(inds_gt,inds_te,eps=1e-3):
+    # true if the two index sets match up to a shuffle along K
+    inds_gt,inds_te = 1.*inds_gt,1.*inds_te
+    args = th.where(th.mean(th.abs(inds_gt-inds_te),dim=-1)>eps)
+    i0,i1 = [],[]
+    for i in range(3):
+        i0.append(inds_gt[...,i][args])
+        i1.append(inds_te[...,i][args])
+    i0 = th.stack(i0,-1)
+    i1 = th.stack(i1,-1)
+    idiffs = th.cdist(i0[None,:],i1[None,:])[0]
+    mins = th.min(idiffs,1).values
+    diff = th.mean(mins).item()
+    return diff < eps
+
+def int_spaced_vid(B,T,F,H,W):
+    # smooth spatio-temporal ramp; a deterministic test video
+    device = "cuda:0"
+    dtype = th.float32
+    grid_y, grid_x = th.meshgrid(th.arange(0, H, dtype=dtype, device=device),
+                                 th.arange(0, W, dtype=dtype, device=device))
+    grid = th.stack((grid_x, grid_y), 0).float()[None,:] # 2, W(x), H(y)
+    vid = []
+    for ti in range(T):
+        g0 = grid[:,[0]].repeat(B,F,1,1)/W-0.5
+        g1 = grid[:,[1]].repeat(B,F,1,1)/H-0.5
+        tN = (ti+1)*th.ones_like(g0)
+        vid.append(g0*g1*tN) # noise less than int
+    vid = th.stack(vid,1)
+    return vid
+
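A small usage sketch of `check_shuffled_inds`: two index tensors that differ only by a permutation along the K axis should compare equal. The shapes are illustrative:

import torch as th
from stnls.testing import check_shuffled_inds

inds = th.rand(2, 8, 5, 3)         # (..., K, 3) indices
perm = th.tensor([4, 0, 1, 2, 3])  # a fixed, non-identity shuffle along K
assert check_shuffled_inds(inds, inds[..., perm, :])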
diff --git a/lib/stnls/testing/gradcheck.py b/lib/stnls/testing/gradcheck.py
index 3d58fe2..3225a8b 100644
--- a/lib/stnls/testing/gradcheck.py
+++ b/lib/stnls/testing/gradcheck.py
@@ -10,6 +10,7 @@

 """

 import torch
+import torch as th

 def get_num_jacobian(fxn,inputs,eps=1e-3,nreps=1):
     from torch.autograd.gradcheck import _get_numerical_jacobian
@@ -30,3 +31,61 @@ def get_gradcheck_pair(fxn,inputs,eps=1e-3):
     num = get_num_jacobian(fxn,inputs,eps=1e-3)
     ana = get_ana_jacobian(fxn,inputs)
     return num,ana
+
+def gradcheck_skip_nan_unstable(fxn,inputs, rtol=1e-05, atol=1e-08,
+                                nreps=3, num_eps=5e-4, unstable_eps=1e-2):
+    # compare only where the numerical jacobian is finite, nonzero, and stable
+    num = get_num_jacobian_skip_unstable(fxn,inputs,eps=num_eps,
+                                         nreps=nreps,unstable_eps=unstable_eps)
+    ana = get_ana_jacobian(fxn,inputs)
+    args = th.where(th.logical_and(~th.isnan(num),num.abs()>0))
+    return th.allclose(num[args],ana[args],atol=atol,rtol=rtol)
+
+def gradcheck_skipnan(fxn,inputs, rtol=1e-05, atol=1e-08, nreps=1, num_eps=5e-4):
+    # compare only where the numerical jacobian is finite and nonzero
+    num = get_num_jacobian(fxn,inputs,eps=num_eps,nreps=nreps)
+    ana = get_ana_jacobian(fxn,inputs)
+    args = th.where(th.logical_and(~th.isnan(num),num.abs()>0))
+    return th.allclose(num[args],ana[args],atol=atol,rtol=rtol)
+
+def get_num_jacobian_skip_unstable(fxn,inputs,eps=1e-3,nreps=1,unstable_eps=1e-2):
+    # average nreps numerical jacobians; mark entries that disagree as nan
+    from torch.autograd.gradcheck import _get_numerical_jacobian
+    nums = []
+    for i in range(nreps):
+        eps_i = eps * (1 + i*eps)
+        num = _get_numerical_jacobian(fxn, (inputs,),
+                                      eps=eps_i, is_forward_ad=False)[0][0]
+        nums.append(num)
+
+    delta = th.zeros_like(nums[0])
+    for i in range(nreps):
+        for j in range(nreps):
+            if i >= j: continue
+            delta += th.abs(nums[i] - nums[j])
+    unstable = th.where(th.logical_or(delta > unstable_eps,th.isnan(delta)))
+    num = th.mean(th.stack(nums),dim=0)
+    num[unstable] = th.nan
+    return num
+
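A sketch of how these helpers are meant to be called. Note they lean on private `torch.autograd.gradcheck` internals, so the exact torch version matters; the toy function below is a stand-in, not the library's search:

import torch as th
from stnls.testing.gradcheck import gradcheck_skipnan

v = th.rand(10, dtype=th.float64, requires_grad=True)
fxn = lambda v: v.square()  # stand-in for a search/refine forward
# compares numerical vs. analytical Jacobians where the numerical
# estimate is finite and nonzero (off-diagonal zeros are skipped)
assert gradcheck_skipnan(fxn, v, num_eps=1e-6)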
diff --git a/tests/search/test_paired_refine.py b/tests/search/test_paired_refine.py
new file mode 100644
index 0000000..6e3b76a
--- /dev/null
+++ b/tests/search/test_paired_refine.py
@@ -0,0 +1,382 @@
+
+# -- python --
+import sys
+
+# -- data mgmt --
+from pathlib import Path
+from easydict import EasyDict as edict
+
+# -- testing --
+import pytest
+import random
+
+# -- linalg --
+import torch as th
+import numpy as np
+from einops import rearrange,repeat
+
+# -- stnls --
+import stnls
+import stnls.utils.gpu_mem as gpu_mem
+from stnls.utils.pads import comp_pads
+from stnls.testing.gradcheck import gradcheck_skipnan,gradcheck_skip_nan_unstable
+from stnls.testing import check_shuffled_inds,int_spaced_vid
+
+# -- test func --
+from torch.nn.functional import fold,unfold,pad
+from torchvision.transforms.functional import center_crop
+
+# -- paths --
+SAVE_DIR = Path("./output/tests/paired_refine")
+
+
+def pytest_generate_tests(metafunc):
+    seed = 123
+    th.manual_seed(seed)
+    np.random.seed(seed)
+    test_lists = {"ws":[3],"wt":[1],"k":[-1],"wr":[1],"kr":[-1],"pt":[1],
+                  "ps":[3],"stride0":[1],"stride1":[1],"dilation":[1],
+                  "self_action":["anchor_each"],"nheads":[1],"seed":[0],
+                  "dist_type":["prod","l2"],"itype":["int","float"],
+                  "reflect_bounds":[True]}
+    for key,val in test_lists.items():
+        if key in metafunc.fixturenames:
+            metafunc.parametrize(key,val)
+
+def set_seed(seed):
+    th.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+def test_fwd_match_refine(ws,wt,wr,kr,k,ps,pt,stride0,stride1,dilation,
+                          self_action,nheads,dist_type,itype,seed,reflect_bounds):
+    # -- load data --
+    device = "cuda:0"
+    set_seed(seed)
+
+    B,HD,T,F,H,W = 1,nheads,3,1,10,10
+    vid = th.ones((B,T,HD*F,H,W),device=device)
+    vid0 = th.rand_like(vid)
+    vid1 = th.rand_like(vid)
+
+    # -- sizes --
+    nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
+    K = 2*wt+1 # number of candidate offsets per query (arbitrary for this test)
+    k_refine = k
+
+    # -- create inds --
+    flows = th.ones((B,HD,T,nH,nW,K,3))+0.1
+    flows = th.rand_like(flows)/2.+1.1
+    tgrid = th.arange(0,T).view(1,1,T,1,1,1)
+    flows[...,0] = th.randint(0,T,size=flows[...,0].shape)-tgrid
+    flows[...,1:] = th.rand_like(flows[...,1:])/2.+0.2
+    not_int = th.all(th.abs(flows[...,1:].round() - flows[...,1:])>1e-5).item()
+    assert not_int,"Gradcheck only works _not_ near an int."
+    flows = flows.to(vid0.device)
+
+    # -- exec fold fxns --
+    refine_gt = stnls.search.RefineSearch(ws, wt, wr, k_refine, kr, ps, nheads,
+                                          dilation=dilation,
+                                          stride0=stride0, stride1=stride1,
+                                          reflect_bounds=reflect_bounds,full_ws=True,
+                                          self_action=self_action,
+                                          dist_type=dist_type,itype=itype)
+    refine_te = stnls.search.PairedRefine(ws, wt, wr, k_refine, kr, ps, nheads,
+                                          dilation=dilation,
+                                          stride0=stride0, stride1=stride1,
+                                          reflect_bounds=reflect_bounds,full_ws=True,
+                                          self_action=self_action,
+                                          dist_type=dist_type,itype=itype)
+
+    # -- test api --
+    dists_gt,inds_gt = refine_gt(vid0, vid1, flows)
+    dists_te,inds_te = refine_te.paired_vids(vid0, vid1, flows, wt)
+
+    # -- compare --
+    assert th.allclose(dists_te,dists_gt,1e-3,1e-3,equal_nan=True)
+
+def test_fwd_match_search(ws,wt,kr,k,ps,pt,stride0,stride1,dilation,
+                          self_action,nheads,dist_type,itype,seed,reflect_bounds):
+    """
+
+    Test the CUDA code with torch code
+
+    Forward Pass
+
+    """
+
+    # -- init vars --
+    device = "cuda:0"
+    wr = 1
+    k_refine = -1
+    set_seed(seed)
+
+    # -- load data --
+    B,HD,T,F,H,W = 1,nheads,3,1,10,10
+    vid = th.ones((B,T,HD*F,H,W),device=device)
+    vid0 = th.rand_like(vid)
+    vid1 = th.rand_like(vid)
+
+    # -- compute flow --
+    nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
+    W_t = min(2*wt+1,T)
+    flows = th.ones((B,HD,T,W_t-1,2,nH,nW)).cuda()/2.
+    flows = th.rand_like(flows)/2.+th.randint_like(flows,-3,3)+0.2
+
+    # -- exec fold fxns --
+    search = stnls.search.PairedSearch(ws, wt, ps, k, nheads,
+                                       dilation=dilation,
+                                       stride0=stride0, stride1=stride1,
+                                       reflect_bounds=reflect_bounds,full_ws=True,
+                                       self_action=self_action,
+                                       dist_type=dist_type,itype=itype)
+    refine = stnls.search.PairedRefine(ws, wt, wr, k_refine, kr, ps, nheads,
+                                       dilation=dilation,
+                                       stride0=stride0, stride1=stride1,
+                                       reflect_bounds=reflect_bounds,full_ws=True,
+                                       self_action=self_action,
+                                       dist_type=dist_type,itype=itype)
+
+    # -- test api --
+    dists_gt,inds_gt = search.paired_vids(vid0, vid1, flows, wt)
+    dists_te,inds_te = refine.paired_vids(vid0, vid1, inds_gt, wt)
+
+    # -- compare --
+    assert th.allclose(dists_te,dists_gt,1e-3,1e-3,equal_nan=True)
+
+# @pytest.mark.slow
+def test_refine_noshuffle_bwd(ws,wt,wr,kr,ps,pt,stride0,stride1,dilation,
+                              self_action,k,nheads,dist_type,itype,seed,reflect_bounds):
+    """
+
+    Test the CUDA code with torch code
+
+    Backward Pass
+
+    """
+
+    # -- init vars --
+    device = "cuda:0"
+    full_ws = True
+    set_seed(seed)
+
+    # -- shapes --
+    W_t = 2*wt+1
+    k,kr = W_t*ws*ws,-1
+    HD,K = nheads,k
+
+    # -- load data --
+    B,T,F,H,W = 1,3,1,8,8
+    nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
+    vid0 = int_spaced_vid(B,T,F,H,W)
+    vid1 = int_spaced_vid(B,T,F,H,W)
+    vid0 = th.rand_like(vid0)/2.+0.2
+    vid1 = th.rand_like(vid1)/2.+0.2
+
+    # -- init for grads --
+    vid0.requires_grad_(True)
+    vid1.requires_grad_(True)
+
+    # -- exec fold fxns --
+    wr = 1
+    refine = stnls.search.PairedRefine(ws, wt, wr, -1, kr, ps, nheads,
+                                       dilation=dilation,stride0=stride0, stride1=stride1,
+                                       reflect_bounds=reflect_bounds,full_ws=full_ws,
+                                       self_action=self_action,
+                                       dist_type=dist_type,itype=itype,topk_mode="all")
+
+    # -- create inds --
+    srch_inds = th.ones((B,HD,T,nH,nW,K,3))+0.1
+    srch_inds = th.rand_like(srch_inds)/2.+1.1
+    tgrid = th.arange(0,T).view(1,1,T,1,1,1)
+    srch_inds[...,0] = th.randint(0,T,size=srch_inds[...,0].shape)-tgrid
+    srch_inds[...,1:] = th.rand_like(srch_inds[...,1:])/2.+0.2
+    not_int = th.all(th.abs(srch_inds[...,1:].round() - srch_inds[...,1:])>1e-5).item()
+    assert not_int,"Gradcheck only works _not_ near an int."
+    srch_inds = srch_inds.to(vid0.device)
+
+    # -- run refinement --
+    ref_dists,ref_inds = refine(vid0,vid1,srch_inds)
+
+    # -- autograd --
+    fxn = lambda vid0: refine(vid0,vid1,srch_inds)[0]
+    assert gradcheck_skipnan(fxn,vid0, atol=1e-02, num_eps=1e-3)
+    fxn = lambda vid1: refine(vid0,vid1,srch_inds)[0]
+    assert gradcheck_skipnan(fxn,vid1, atol=1e-02, num_eps=1e-3)
+
+    # -- autograd check for indices --
+    if itype == "float":
+        srch_inds_t = srch_inds[...,[0]]
+        srch_inds_sp = srch_inds[...,1:].requires_grad_(True)
+        def fxn(srch_inds_sp):
+            srch_inds = th.cat([srch_inds_t,srch_inds_sp],-1).requires_grad_(True)
+            return refine(vid0,vid1,srch_inds)[1]
+        assert gradcheck_skip_nan_unstable(fxn, srch_inds_sp, atol=1e-02,
+                                           nreps=3, num_eps=1e-3)
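The float-itype branch above checks gradients for the spatial components only: the leading channel holds integer frame offsets, which are not differentiable, so it is split off and re-concatenated inside the function under test. The same pattern in a self-contained toy, with a stand-in function rather than the real refine:

import torch as th

inds = th.rand(4, 3, dtype=th.float64)
inds_t = inds[..., [0]]                             # time channel: excluded
inds_sp = inds[..., 1:].requires_grad_(True)        # spatial channels: checked
fxn = lambda sp: th.cat([inds_t, sp], -1).square()  # stand-in for refine(...)
assert th.autograd.gradcheck(fxn, (inds_sp,))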
+
+# NOTE: anchor and topk forward tests for refinement currently live in
+# tests/search/test_refinement.py (test_anchor_fwd, test_fwd_topk).
+
diff --git a/tests/search/test_paired_search.py b/tests/search/test_paired_search.py
index 2932512..c7d0c87 100644
--- a/tests/search/test_paired_search.py
+++ b/tests/search/test_paired_search.py
@@ -17,9 +17,7 @@

 # -- stnls --
 import stnls
-
-# -- meshgrid --
-
+from stnls.testing import check_shuffled_inds

 # -- test func --
 from torch.nn.functional import fold,unfold,pad
@@ -35,45 +33,30 @@ def set_seed(seed):
     random.seed(seed)

 def pytest_generate_tests(metafunc):
-    test_lists = {"ps":[5],"stride0":[1,2],"stride1":[1.1],
-                  "dilation":[1],"wt":[1,2],"ws":[1,5],
-                  "k":[-1],"nheads":[1],"self_action":[None],
-                  "seed":[0,1,2],"dist_type":["prod","l2"],
-                  "itype":["int","float"],"reflect_bounds":[True]}
+    test_lists = {"ws":[3],"wt":[1],"k":[-1],"pt":[1],
+                  "ps":[3],"stride0":[1],"stride1":[1],"dilation":[1],
+                  "self_action":["anchor_each",None],"nheads":[1],"seed":[0],
+                  "dist_type":["l2","prod"],"itype":["int","float"],
+                  "reflect_bounds":[True]}
     for key,val in test_lists.items():
         if key in metafunc.fixturenames:
             metafunc.parametrize(key,val)

-def check_shuffled_inds(inds_gt,inds_te,eps=1e-3):
-    args = th.where(th.mean(th.abs(inds_gt-inds_te),dim=-1)>eps)
-    i0,i1 = [],[]
-    for i in range(3):
-        i0.append(inds_gt[...,i][args])
-        i1.append(inds_te[...,i][args])
-    i0 = th.stack(i0,-1)
-    i1 = th.stack(i1,-1)
-    idiffs = th.cdist(i0[None,:],i1[None,:])[0]
-    mins = th.min(idiffs,1).values
-    diff = th.sum(mins).item()
-    return diff < 1e-4
-
-def test_fwd(ws,wt,k,ps,stride0,stride1,dilation,
+def test_fwd(ws,wt,ps,pt,stride0,stride1,dilation,
              nheads,self_action,dist_type,seed,itype,reflect_bounds):

     # -- get args --
-    dil = dilation
-    ext = "jpg"
-    dnames = ["davis_baseball_64x64","davis_baseball_64x64"]
-    pt = 1
-    seed = 234
     device = "cuda:0"
-    use_adj = False
     set_seed(seed)
     W_t = 2*wt+1
-    k = W_t*ws*ws if k == 0 else -1
+    k = W_t*ws*ws

     # -- load data --
-    B,T,F,H,W = 2,10,16,16,8
+    B,T,F,H,W = 1,5,3,16,16
     vid = th.ones((B,T,F,H,W),device=device).float()
     vid0 = th.randn_like(vid)-0.5
     vid1 = th.randn_like(vid)
@@ -82,18 +65,19 @@ def test_fwd(ws,wt,ps,pt,stride0,stride1,dilation,
     nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
     flows = th.ones((B,1,T,W_t-1,2,nH,nW)).cuda()#/2.
     flows = th.rand_like(flows)/2.+th.randint_like(flows,-2,2)+0.2
+    # flows = th.zeros_like(flows)

     # -- exec fold fxns --
     sch = stnls.search
     search_gt = sch.NonLocalSearch(ws, wt, ps, k, nheads, dist_type=dist_type,
-                                   dilation=dil,stride0=stride0, stride1=stride1,
+                                   dilation=dilation,stride0=stride0, stride1=stride1,
                                    reflect_bounds=reflect_bounds,full_ws=True,
-                                   self_action=self_action,use_adj=use_adj,
+                                   self_action=self_action,
                                    topk_mode="each",itype=itype)
     search_te = sch.PairedSearch(ws, ps, k, nheads, dist_type=dist_type,
-                                 dilation=dil,stride0=stride0, stride1=stride1,
+                                 dilation=dilation,stride0=stride0, stride1=stride1,
                                  reflect_bounds=reflect_bounds,full_ws=True,
-                                 self_action=self_action,use_adj=use_adj,
+                                 self_action=self_action,
                                  itype=itype)

     # -- [groundtruth] search --
@@ -135,7 +119,7 @@ def test_bwd(ws,wt,k,ps,stride0,stride1,dilation,
     W_t = 2*wt+1
     flows = th.ones((B,1,T,W_t-1,2,nH,nW)).cuda()/2.
     flows = th.rand_like(flows)/2.+th.randint_like(flows,-2,2)+0.2
-    # flows = th.rand_like(flows)/2.+0.2 # away from int
+    # flows = th.zeros_like(flows)

     # -- unpack image --
     device = vid.device
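The flow initialization used in these tests keeps every fractional part inside [0.2, 0.7): per the asserts elsewhere in these files, gradcheck only behaves away from integer offsets, where the interpolation of distances is smooth. A sketch of the construction:

import torch as th

flows = th.rand(1000)/2. + th.randint(-2, 2, (1000,)) + 0.2
frac = flows - flows.floor()  # fractional part, also for negative values
assert th.all((frac >= 0.2) & (frac < 0.7))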
diff --git a/tests/search/test_refinement.py b/tests/search/test_refinement.py
index 4f9f9d6..362b71b 100644
--- a/tests/search/test_refinement.py
+++ b/tests/search/test_refinement.py
@@ -19,7 +19,8 @@

 import stnls
 import stnls.utils.gpu_mem as gpu_mem
 from stnls.utils.pads import comp_pads
-# from stnls.utils.inds import get_batching_info
+from stnls.testing.gradcheck import gradcheck_skipnan,gradcheck_skip_nan_unstable
+from stnls.testing import check_shuffled_inds,int_spaced_vid

 # -- test func --
 from torch.nn.functional import fold,unfold,pad
@@ -29,83 +30,6 @@ SAVE_DIR = Path("./output/tests/non_local_search")

-def gradcheck_skip_nan_unstable(fxn,inputs, rtol=1e-05, atol=1e-08,
-                                nreps=3, num_eps=5e-4, unstable_eps=1e-2):
-    num = get_num_jacobian_skip_unstable(fxn,inputs,eps=num_eps,
-                                         nreps=nreps,unstable_eps=unstable_eps)
-    ana = get_ana_jacobian(fxn,inputs)
-    args = th.where(th.logical_and(~th.isnan(num),num.abs()>0))
-    args1 = th.where(th.abs(num[args]-ana[args])>1e-2)[0]
-    return th.allclose(num[args],ana[args],atol=atol,rtol=rtol)
-
-def gradcheck_skipnan(fxn,inputs, rtol=1e-05, atol=1e-08, nreps=1, num_eps=5e-4):
-    num = get_num_jacobian(fxn,inputs,eps=num_eps,nreps=nreps)
-    ana = get_ana_jacobian(fxn,inputs)
-    args = th.where(th.logical_and(~th.isnan(num),num.abs()>0))
-    args1 = th.where(th.abs(num[args]-ana[args])>1e-2)[0]
-    return th.allclose(num[args],ana[args],atol=atol,rtol=rtol)
-
-def get_num_jacobian_skip_unstable(fxn,inputs,eps=1e-3,nreps=1,unstable_eps=1e-2):
-    from torch.autograd.gradcheck import _get_numerical_jacobian
-    nums = []
-    for i in range(nreps):
-        eps_i = eps * (1 + i*eps)
-        num = _get_numerical_jacobian(fxn, (inputs,),
-                                      eps=eps_i, is_forward_ad=False)[0][0]
-        nums.append(num)
-
-    delta = th.zeros_like(nums[0])
-    for i in range(nreps):
-        for j in range(nreps):
-            if i >= j: continue
-            delta += th.abs(nums[i] - nums[j])
-    unstable = th.where(th.logical_or(delta > unstable_eps,th.isnan(delta)))
-    num = th.mean(th.stack(nums),dim=0)
-    num[unstable] = th.nan
-    return num
-
-def get_num_jacobian(fxn,inputs,eps=1e-3,nreps=1):
-    from torch.autograd.gradcheck import _get_numerical_jacobian
-    num = _get_numerical_jacobian(fxn, (inputs,),
-                                  eps=eps, is_forward_ad=False)[0][0]
-    for i in range(nreps-1):
-        num += get_num_jacobian(fxn,inputs,eps=eps)
-    num /= nreps
-    return num
-
-def get_ana_jacobian(fxn,inputs):
-    from torch.autograd.gradcheck import _check_analytical_jacobian_attributes
-    out = fxn(inputs)
-    ana = _check_analytical_jacobian_attributes((inputs,), out, 1e-7, False)[0]
-    return ana
-
-def get_gradcheck_pair(fxn,inputs,eps=1e-3):
-    num = get_num_jacobian(fxn,inputs,eps=1e-3)
-    ana = get_ana_jacobian(fxn,inputs)
-    return num,ana
-
 def pytest_generate_tests(metafunc):
     seed = 123
     th.manual_seed(seed)
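With the helpers promoted into the library, other test modules can share them through the same imports this hunk adds:

from stnls.testing.gradcheck import gradcheck_skipnan, gradcheck_skip_nan_unstable
from stnls.testing import check_shuffled_inds, int_spaced_vid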
@@ -147,6 +71,7 @@ def test_refine_fwd(ws,wt,wr,kr,k,ps,stride0,stride1,dilation,
     self_action = None
     ext = "jpg"
     dnames = ["davis_baseball_64x64","davis_baseball_64x64"]
+    set_seed(seed)

     # -- load data --
     vid = stnls.testing.data.load_burst_batch("./data/",dnames,ext=ext)
@@ -238,7 +163,6 @@ def test_refine_noshuffle_bwd(ws,wt,wr,kr,ps,stride0,stride1,dilation,
     vid1 = th.rand_like(vid0)/2.+0.2

     # -- init for grads --
-    vid0_srch,vid1_srch = vid0.clone(),vid1.clone()
     vid0.requires_grad_(True)
     vid1.requires_grad_(True)
@@ -316,6 +240,7 @@ def test_anchor_fwd(ws,wt,wr,ps,stride0,stride1,dilation,
     dnames = ["davis_baseball_64x64","davis_baseball_64x64"]
     topk_mode = "each"
     kr = -1
+    set_seed(seed)

     # -- load data --
     vid = stnls.testing.data.load_burst_batch("./data/",dnames,ext=ext)
@@ -399,6 +324,7 @@ def test_fwd_topk(ws,wt,wr,ps,stride0,stride1,dilation,dist_type,seed,reflect_bo
     itype = "float"
     nheads = 1
     kr = -1
+    set_seed(seed)

     # -- load data --
     vid = stnls.testing.data.load_burst_batch("./data/",dnames,ext=ext)
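Finally, the `set_seed(seed)` calls added across these tests pin all three RNGs, so a parametrized case replays identically run-to-run; a minimal check of that property:

import random
import numpy as np
import torch as th

def set_seed(seed):
    th.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed(0)
a = th.rand(3), np.random.rand(3), random.random()
set_seed(0)
b = th.rand(3), np.random.rand(3), random.random()
assert th.equal(a[0], b[0]) and np.allclose(a[1], b[1]) and a[2] == b[2]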