Skip to content

Commit

Permalink
added paired refine skeleton
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Oct 29, 2023
1 parent e269014 commit 2e81bb5
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 87 deletions.
2 changes: 2 additions & 0 deletions lib/csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
void init_non_local_search(py::module &);
void init_refinement(py::module &);
void init_paired_search(py::module &);
void init_paired_refine(py::module &);
void init_n3net_matmult1(py::module &);


Expand All @@ -23,6 +24,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
init_non_local_search(m);
init_refinement(m);
init_paired_search(m);
init_paired_refine(m);
init_n3net_matmult1(m);

// -- nn --
Expand Down
11 changes: 10 additions & 1 deletion lib/stnls/nn/anchor_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,25 @@ def run(dists,inds,stride0,H,W,qstart=0):
# -- view --
# print(dists.shape)
B,HD,Q,Ks,ws,ws = dists.shape
d2or3 = inds.shape[-1]
dshape,ishape = list(dists.shape),list(inds.shape)
dists = dists.view(B*HD,Q,Ks*ws*ws)
inds = inds.view(B*HD,Q,Ks*ws*ws,3)
inds = inds.view(B*HD,Q,Ks*ws*ws,d2or3)

# -- [patchwork] --
if d2or3 == 2:
inds = th.cat([th.zeros_like(inds[...,[0]]),inds],-1)

# -- allocate --
order = th.zeros_like(dists[...,0]).int()

# -- run --
stnls_cuda.anchor_self(dists,inds,order,stride0,H,W)

# -- [patchwork] --
if d2or3 == 2:
inds = inds[...,1:].contiguous()

# -- return --
dists = dists.reshape(dshape)
inds = inds.reshape(ishape)
Expand Down
2 changes: 2 additions & 0 deletions lib/stnls/search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
"refine":"refinement",
"pair":"paired_search",
"paired":"paired_search",
"paired_refine":"paired_refine",
"paired_ref":"paired_refine",
"rand_inds":"rand_inds",
"n3mm":"n3mm_search"})

Expand Down
84 changes: 73 additions & 11 deletions lib/stnls/search/paired_bwd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
def paired_backward(ctx, grad_dists, grad_inds):

# -- populate names --
inds,vid0,vid1,flow = ctx.saved_tensors
inds,frame0,frame1,flow = ctx.saved_tensors
itype_bwd = ctx.itype
inds = get_inds(inds,itype_bwd)
grad_flow = allocate_grad_flows(itype_bwd,flow.shape,flow.device)

# -- allocate grads --
grad_vid0 = allocate_vid(ctx.vid_shape,grad_dists.device)
grad_vid1 = allocate_vid(ctx.vid_shape,grad_dists.device)
grad_frame0 = allocate_vid(ctx.vid_shape,grad_dists.device)
grad_frame1 = allocate_vid(ctx.vid_shape,grad_dists.device)
# return grad_vid0,grad_vid1,grad_flow

# -- restrict to k_agg --
Expand All @@ -49,29 +49,91 @@ def paired_backward(ctx, grad_dists, grad_inds):
if itype_bwd == "int":
bwd_fxn = stnls_cuda.paired_search_int_backward
inds = inds.view(B,HD,T*nH*nW,K,2)
bwd_fxn(grad_vid0,grad_vid1,
vid0,vid1,grad_dists,inds,
bwd_fxn(grad_frame0,grad_frame1,
frame0,frame1,grad_dists,inds,
ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds,
patch_offset,ctx.dist_type_i)
else:
bwd_fxn = stnls_cuda.paired_search_bilin2d_backward
bwd_fxn(grad_vid0,grad_vid1,grad_flow,
vid0,vid1,flow,grad_dists,grad_inds,inds,
bwd_fxn(grad_frame0,grad_frame1,grad_flow,
frame0,frame1,flow,grad_dists,grad_inds,inds,
ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds,
patch_offset,ctx.dist_type_i)

# -- finalize shape --
if ctx.in_ndim == 4:
grad_vid0 = rearrange(grad_vid0,'B H c h w -> B (H c) h w')
grad_vid1 = rearrange(grad_vid1,'B H c h w -> B (H c) h w')
grad_frame0 = rearrange(grad_frame0,'B H c h w -> B (H c) h w')
grad_frame1 = rearrange(grad_frame1,'B H c h w -> B (H c) h w')

# -- normz --
if ctx.normalize_bwd:
normz_bwd(ctx,grad_vid0,grad_vid1)
normz_bwd(ctx,grad_frame0,grad_frame1)

# -- no grad if ints --
if itype_bwd == "int":
grad_flow = None

return grad_vid0,grad_vid1,grad_flow
return grad_frame0,grad_frame1,grad_flow


def paired_refine_backward(ctx, grad_dists, grad_inds):
    """Backward pass of the paired refinement search.

    Reads the tensors saved on ``ctx`` during the forward pass, dispatches
    to the integer or bilinear CUDA backward kernel depending on
    ``ctx.itype``, and returns the gradients for the two frames and the flow.

    Args:
        ctx: autograd context. This function reads ``saved_tensors``
            (inds, frame0, frame1, flow) and the attributes ``itype``,
            ``vid_shape``, ``k_agg``, ``use_adj``, ``ps``, ``stride0``,
            ``dil``, ``reflect_bounds``, ``dist_type_i``, ``in_ndim`` and
            ``normalize_bwd``.
        grad_dists: upstream gradient w.r.t. the returned distances;
            viewed as (B, HD, T*nH*nW, K) below.
        grad_inds: upstream gradient w.r.t. the returned indices;
            viewed as (B, HD, T*nH*nW, K, 2) below.

    Returns:
        Tuple ``(grad_frame0, grad_frame1, grad_flow)``; ``grad_flow`` is
        ``None`` when ``ctx.itype == "int"``.
    """

    # -- populate names --
    inds,frame0,frame1,flow = ctx.saved_tensors
    itype_bwd = ctx.itype
    inds = get_inds(inds,itype_bwd)
    grad_flow = allocate_grad_flows(itype_bwd,flow.shape,flow.device)

    # -- allocate grads --
    grad_frame0 = allocate_vid(ctx.vid_shape,grad_dists.device)
    grad_frame1 = allocate_vid(ctx.vid_shape,grad_dists.device)

    # -- restrict to k_agg --
    # grad_inds is truncated together with inds/grad_dists so that all three
    # agree on K; otherwise the view below fails whenever 0 < k_agg < K.
    if ctx.k_agg > 0:
        grad_dists = grad_dists[...,:ctx.k_agg]
        grad_inds = grad_inds[...,:ctx.k_agg,:]
        inds = inds[...,:ctx.k_agg,:]

    # -- ensure contiguous --
    # must happen BEFORE .view(): slicing (or a non-contiguous upstream
    # gradient) yields tensors that .view() cannot reinterpret.
    grad_dists = grad_dists.contiguous()
    grad_inds = grad_inds.contiguous()
    inds = inds.contiguous()

    # -- view --
    B,HD,T,nH,nW,K,_ = inds.shape
    Q = T*nH*nW
    inds = inds.view(B,HD,Q,K,2)
    grad_inds = grad_inds.view(B,HD,Q,K,2)
    grad_dists = grad_dists.view(B,HD,Q,K)
    patch_offset = 0 if ctx.use_adj else -(ctx.ps//2)

    # -- run backward kernel --
    if itype_bwd == "int":
        # integer indices: only the frames receive gradients
        bwd_fxn = stnls_cuda.paired_refine_int_backward
        bwd_fxn(grad_frame0,grad_frame1,
                frame0,frame1,grad_dists,inds,
                ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds,
                patch_offset,ctx.dist_type_i)
    else:
        # bilinear indices: the kernel additionally fills grad_flow
        bwd_fxn = stnls_cuda.paired_refine_bilin2d_backward
        bwd_fxn(grad_frame0,grad_frame1,grad_flow,
                frame0,frame1,flow,grad_dists,grad_inds,inds,
                ctx.stride0,ctx.ps,ctx.dil,ctx.reflect_bounds,
                patch_offset,ctx.dist_type_i)

    # -- finalize shape --
    # a 4-dim forward input had its heads split out of the channel axis;
    # fold them back so the gradient matches the input layout.
    if ctx.in_ndim == 4:
        grad_frame0 = rearrange(grad_frame0,'B H c h w -> B (H c) h w')
        grad_frame1 = rearrange(grad_frame1,'B H c h w -> B (H c) h w')

    # -- normz --
    if ctx.normalize_bwd:
        normz_bwd(ctx,grad_frame0,grad_frame1)

    # -- no grad if ints --
    if itype_bwd == "int":
        grad_flow = None

    return grad_frame0,grad_frame1,grad_flow


77 changes: 5 additions & 72 deletions lib/stnls/search/paired_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,11 @@ def paired_forward(frame0, frame1, flow,
nW0 = (W-1)//stride0+1
Q = nH0*nW0

# -- search space --
ws_h,ws_w = ws,ws

# -- settings from distance type --
dist_type_i,descending,idist_val = dist_type_select(dist_type)

# -- allocate results --
base_shape = (B,HD,Q,ws_h,ws_w)
base_shape = (B,HD,Q,ws,ws)
dists,inds = allocate_pair_2d(base_shape,device,frame0.dtype,idist_val,itype)
# print("inds.shape: ",inds.shape)

Expand All @@ -84,29 +81,19 @@ def paired_forward(frame0, frame1, flow,

# print(frame0.shape,frame1.shape,flow.shape,inds.shape)
# -- compress search region --
dists=dists.view(B,HD,Q,-1)
inds=inds.view(B,HD,Q,-1,2)
# th.cuda.synchronize()
# dists=dists.view(B,HD,Q,-1)
# inds=inds.view(B,HD,Q,-1,2)
# # th.cuda.synchronize()

# -- anchor --
assert self_action in [None,"anchor","anchor_each"]
anchor_self = False if self_action is None else "anchor" in self_action
if self_action is None: pass
elif "anchor" in self_action:
stnls.nn.anchor_self(dists,inds,stride0,H,W)
stnls.nn.anchor_self(dists[...,None,:,:],inds[...,None,:,:,:],stride0,H,W)
else:
raise ValueError(f"Uknown option for self_action [{self_action}]")

# # # -- manage self dists --
# # anchor_self = self_action == "anchor"
# # remove_self = self_action == "remove"
# # inds = th.cat([th.zeros_like(inds[...,[0]]),inds],-1)
# # dists,inds = manage_self(dists,inds,anchor_self,
# # remove_self,0,stride0,H,W)
# inds = inds[...,1:]
# # print(inds.shape)
# # print(inds[0,0,:,0])
# # exit()

# -- topk --
if k > 0:
Expand Down Expand Up @@ -235,60 +222,6 @@ def __init__(self, ws, ps, k, nheads=1,
def paired_vids(self, vid0, vid1, flows, wt, skip_self=False):
return _paired_vids(self.forward, vid0, vid1, flows, wt, skip_self)

# def paired_stacking(self, vid0, vid1, acc_flows, wt, stack_fxn):
# dists,inds = [],[]
# T = vid0.shape[1]
# zflow = th.zeros_like(acc_flows.fflow[:,0,0])
# for ti in range(T):
# # if ti != 1: continue

# swap = False
# t_inc = 0
# prev_t = ti
# t_shift = min(0,ti-wt) + max(0,ti + wt - (T-1))
# t_max = min(T-1,ti + wt - t_shift);
# # print(t_shift,t_max)
# tj = ti

# dists_i,inds_i = [],[]
# for _tj in range(2*wt+1):

# # -- update search frame --
# prev_t = tj
# tj = prev_t + t_inc
# swap = tj > t_max
# t_inc = 1 if (t_inc == 0) else t_inc
# t_inc = -1 if swap else t_inc
# tj = ti-1 if swap else tj
# prev_t = ti if swap else prev_t
# # print(ti,tj,t_inc,swap)

# frame0 = vid0[:,ti]
# frame1 = vid1[:,tj]
# if ti == tj:
# flow = zflow
# elif ti < tj:
# # print("fwd: ",ti,tj,tj-ti-1)
# # flow = acc_flows.fflow[:,tj - ti - 1]
# flow = acc_flows.fflow[:,ti,tj-ti-1]
# elif ti > tj:
# # print("bwd: ",ti,tj,ti-tj-1)
# # flow = acc_flows.bflow[:,ti - tj - 1]
# flow = acc_flows.bflow[:,ti,ti-tj-1]
# flow = flow.float()
# dists_ij,inds_ij = self.forward(frame0,frame1,flow)
# inds_t = (tj-ti)*th.ones_like(inds_ij[...,[0]])
# inds_ij = th.cat([inds_t,inds_ij],-1)
# dists_i.append(dists_ij)
# inds_i.append(inds_ij)
# dists_i = th.cat(dists_i,-1)
# inds_i = th.cat(inds_i,-2)
# dists.append(dists_i)
# inds.append(inds_i)
# dists = th.cat(dists,-2)
# inds = th.cat(inds,-3)
# return dists,inds

def forward(self, frame0, frame1, flow):
assert self.ws > 0,"Must have nonzero spatial search window"
return PairedSearchFunction.apply(frame0,frame1,flow,
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
'lib/csrc/search/refinement.cpp', # Search over K offsets
'lib/csrc/search/refinement_int_kernel.cu',
'lib/csrc/search/refinement_bilin2d_kernel.cu',
'lib/csrc/search/paired_search.cpp', # Space-Time Search (Pair of Frames)
'lib/csrc/search/paired_search.cpp', # Paired Search
'lib/csrc/search/paired_search_kernel.cu',
'lib/csrc/search/paired_refine.cpp', # Paired Refinement
'lib/csrc/search/paired_refine_kernel.cu',
'lib/csrc/search/mat_mult1.cpp', # Space-Time Search (Pair of Frames)
'lib/csrc/search/mat_mult1_kernel.cu',
# -- nn --
Expand Down
4 changes: 2 additions & 2 deletions tests/search/test_paired_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_fwd(ws,wt,k,ps,stride0,stride1,dilation,
device = "cuda:0"
use_adj = False
set_seed(seed)
k = ws*ws if k == 0 else -1
W_t = 2*wt+1
k = W_t*ws*ws if k == 0 else -1

# -- load data --
B,T,F,H,W = 2,10,16,16,8
Expand All @@ -79,7 +80,6 @@ def test_fwd(ws,wt,k,ps,stride0,stride1,dilation,

# -- load flows --
nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
W_t = 2*wt+1
flows = th.ones((B,1,T,W_t-1,2,nH,nW)).cuda()#/2.
flows = th.rand_like(flows)/2.+th.randint_like(flows,-2,2)+0.2

Expand Down

0 comments on commit 2e81bb5

Please sign in to comment.