Skip to content

Commit

Permalink
added refine to apig
Browse files Browse the repository at this point in the history
  • Loading branch information
gauenk committed Oct 30, 2023
1 parent 08ff6d3 commit b0b7906
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 19 deletions.
9 changes: 6 additions & 3 deletions lib/stnls/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# -- packages --
from . import non_local_search as non_local_search_f
from . import refinement as refinement_f
from . import paired_search as paired_f
from . import paired_search as paired_search_f
from . import paired_refine as paired_refine_f
from . import n3mm_search as n3mm_search_f
from .utils import empty_flow,search_wrap
from .utils import get_time_window_inds
Expand All @@ -15,12 +16,14 @@
# -- functional api --
nls = non_local_search_f._apply
refine = refinement_f._apply
paired = paired_f._apply
paired_search = paired_search_f._apply
paired_refine = paired_refine_f._apply
n3mm = n3mm_search_f._apply

# -- class api --
NonLocalSearch = non_local_search_f.NonLocalSearch
RefineSearch = refinement_f.RefineSearch
PairedSearch = paired_f.PairedSearch
PairedSearch = paired_search_f.PairedSearch
PairedRefine = paired_refine_f.PairedRefine
N3MatMultSearch = n3mm_search_f.N3MatMultSearch

1 change: 0 additions & 1 deletion lib/stnls/search/paired_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from .utils import shape_frames,allocate_pair_2d,dist_type_select,allocate_vid
from .utils import get_ctx_shell,ensure_flow_shape,ensure_paired_flow_dim
from .shared import reflect_bounds_warning
from .paired_bwd_impl import paired_refine_backward
from .utils import paired_vids as _paired_vids

# -- implementation --
Expand Down
35 changes: 20 additions & 15 deletions tests/search/test_paired_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@


def pytest_generate_tests(metafunc):
test_lists = {"ws":[3],"wt":[1],"k":[-1],"wr":[1],"kr":[-1],"pt":[1],
test_lists = {"ws":[3],"wt":[1],"K":[-1],"wr":[1],"kr":[-1],"pt":[1],
"ps":[3],"stride0":[1],"stride1":[1],"dilation":[1],
"self_action":["anchor_each"],"nheads":[1],"seed":[0],
"dist_type":["prod","l2"],"itype":["int","float"],
"reflect_bounds":[True]}
"topk_mode":["all"],"reflect_bounds":[True]}
for key,val in test_lists.items():
if key in metafunc.fixturenames:
metafunc.parametrize(key,val)
Expand All @@ -45,18 +45,24 @@ def set_seed(seed):
np.random.seed(seed)
random.seed(seed)

def test_fwd_match_refine(ws,wt,wr,kr,k,ps,pt,stride0,stride1,dilation,
self_action,nheads,dist_type,itype,seed,reflect_bounds):
# -- load data --
def test_fwd_match_refine(ws,wt,wr,kr,K,ps,pt,stride0,stride1,dilation,
self_action,nheads,dist_type,topk_mode,itype,
seed,reflect_bounds):
# -- init --
B,HD,T,F,H,W = 1,nheads,3,1,10,10
W_t = min(T,2*wt+1)
K = W_t*ws*ws if K <= 0 else K
K_refine = int(K*kr)
device = "cuda:0"
set_seed(seed)

B,HD,T,F,H,W = 1,nheads,3,1,10,10
# -- video data --
vid = th.ones((B,T,HD*F,H,W),device=device)
vid0 = th.rand_like(vid)#.requires_grad_(True)
vid1 = th.rand_like(vid)#.requires_grad_(True)

# -- create inds --
nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
flows = th.ones((B,HD,T,nH,nW,K,3))+0.1
flows = th.rand_like(flows)/2.+1.1
tgrid = th.arange(0,T).view(1,1,T,1,1,1)
Expand All @@ -69,17 +75,17 @@ def test_fwd_match_refine(ws,wt,wr,kr,k,ps,pt,stride0,stride1,dilation,
# srch_inds = srch_inds.requires_grad_(True)

# -- exec fold fxns --
refine_gt = stnls.search.RefineSearch(ws, wt, wr, k_refine, kr, ps, nheads,
refine_gt = stnls.search.RefineSearch(ws, wt, wr, K_refine, kr, ps, nheads,
dilation=dilation,
stride0=stride0, stride1=stride1,
reflect_bounds=reflect_bounds,full_ws=True,
self_action=self_action,
self_action=self_action,topk_mode=topk_mode,
dist_type=dist_type,itype=itype)
refine_te = stnls.search.PairedRefine(wr, ws, ps, k_refine, kr, nheads,
refine_te = stnls.search.PairedRefine(wr, ws, ps, K_refine, kr, nheads,
dilation=dilation,
stride0=stride0, stride1=stride1,
reflect_bounds=reflect_bounds,full_ws=True,
self_action=self_action,
self_action=self_action,topk_mode=topk_mode,
dist_type=dist_type,itype=itype)

# -- test api --
Expand All @@ -102,7 +108,7 @@ def test_fwd_match_search(ws,wt,kr,ps,pt,stride0,stride1,dilation,
# -- init vars --
device = "cuda:0"
wr = 1
W_t = 2*wt+1
W_t = min(2*wt+1,T)
k = W_t*ws*ws
set_seed(seed)

Expand All @@ -114,13 +120,12 @@ def test_fwd_match_search(ws,wt,kr,ps,pt,stride0,stride1,dilation,

# -- compute flow --
nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
W_t = min(2*wt+1,T)
flows = th.ones((B,HD,T,W_t-1,2,nH,nW)).cuda()/2.
flows = th.rand_like(flows)/2.+th.randint_like(flows,-3,3)+0.2
# flows = flows.requires_grad_(True)

# -- exec fold fxns --
search = stnls.search.PairedSearch(ws, wt, ps, k, nheads,
search = stnls.search.PairedSearch(ws, wt, ps, K, nheads,
dilation=dilation,
stride0=stride0, stride1=stride1,
reflect_bounds=reflect_bounds,full_ws=True,
Expand All @@ -142,7 +147,7 @@ def test_fwd_match_search(ws,wt,kr,ps,pt,stride0,stride1,dilation,

# @pytest.mark.slow
def test_refine_noshuffle_bwd(ws,wt,wr,kr,ps,pt,stride0,stride1,dilation,
self_action,k,nheads,dist_type,itype,seed,reflect_bounds):
self_action,K,nheads,dist_type,itype,seed,reflect_bounds):
"""
Test the CUDA code with torch code
Expand All @@ -158,7 +163,7 @@ def test_refine_noshuffle_bwd(ws,wt,wr,kr,ps,pt,stride0,stride1,dilation,

# -- shapes --
W_t = 2*wt+1
k,kr = W_t*ws*ws,-1
K,kr = W_t*ws*ws,-1
HD,K = nheads,k

# -- load data --
Expand Down

0 comments on commit b0b7906

Please sign in to comment.