Commit 3b3778c

update paired refine

gauenk committed Nov 25, 2023
1 parent c35683f commit 3b3778c
Showing 10 changed files with 77 additions and 54 deletions.
44 changes: 25 additions & 19 deletions dev/named_full_ws.py
@@ -97,6 +97,9 @@ def get_unique_index(nl_hi,nl_wi,hi,wi,
ws_j = abs(num_w) - (wsHalf+1) + (wsHalf)
ws_i = num_h+wsHalf

# -- debug --
# print((hi,wi),(nl_hi,nl_wi),(oob_i,oob_j,))

# -- standard names --
if not(oob_i or oob_j):
ws_i = num_h + wsHalf
@@ -123,7 +126,7 @@ def get_tlims(ti, T, wt):
t_max = min(T-1,ti + wt - t_shift);
return t_max,t_min

def fill_names(ti,hi,wi,ki,ws,wt,stride0,stride1,st_offset,
def fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,st_offset,
full_ws,names,counts,flows_k):

# -- unpack --
@@ -133,12 +136,14 @@ def fill_names(ti,hi,wi,ki,ws,wt,stride0,stride1,st_offset,
wsOff_h,wsOff_w = wsHalf,wsHalf

# -- get non-local position --
nl_ti = ti + flows_k[ti][hi][wi][ki][0]
nl_hi = hi + flows_k[ti][hi][wi][ki][1]
nl_wi = wi + flows_k[ti][hi][wi][ki][2]
hi,wi = h_ref*stride0,w_ref*stride0
nl_ti = ti + flows_k[ti][h_ref][w_ref][ki][0]
nl_hi = hi + flows_k[ti][h_ref][w_ref][ki][1]
nl_wi = wi + flows_k[ti][h_ref][w_ref][ki][2]
valid = check_valid(nl_ti,nl_hi,nl_wi,T,H,W)
if not(valid): return
# if not((wi == 0) or (hi == 0)): return
# if not((wi == 0) and (hi == 0)): return
# if not((nl_hi == 1) and (nl_wi == 0)): return
# if not((nl_hi == 0) and (nl_wi == 2)): return
# if not((nl_hi == 0) and (nl_wi == 2)): return
@@ -154,18 +159,18 @@ def fill_names(ti,hi,wi,ki,ws,wt,stride0,stride1,st_offset,
dto = t_max - ti
si = (dt-st_offset) if (dt >= 0) else (dto - dt - st_offset)
# ws_ti = (ti+nl_ti) % W_t
ws_ti = (nl_ti+ti) % T
ws_ti = (nl_ti+ti) % T if W_t > 1 else 0
# print(si,dt,st_offset,nl_ti,ti)

# -- offset search offsets --
wsOff_h,wsOff_w = set_search_offsets(wsOff_h, wsOff_w, hi, wi,
stride1, wsHalf, ws, H, W, full_ws)

# -- get search index --
ws_i = (nl_hi - hi)//stride1 + wsOff_h
ws_j = (nl_wi - wi)//stride1 + wsOff_w
ws_i_orig = ws_i
ws_j_orig = ws_j
# ws_i = (nl_hi - hi)//stride1 + wsOff_h
# ws_j = (nl_wi - wi)//stride1 + wsOff_w
# ws_i_orig = ws_i
# ws_j_orig = ws_j

# -- handle oob --
time_offset = ws_ti*(ws*ws+2*(ws//2)*ws+(ws//2)**2)
@@ -200,13 +205,13 @@ def set_seed(seed):
random.seed(seed)

def main():
ws = 3
wt = 1
ws = 9
wt = 0
W_t = 2*wt+1
full_ws = True
T,H,W = 5,32,32
stride0,stride1 = 1,1
W_t_num = T#min(W_t + 2*wt,T)
T,H,W = 5,64,64
stride0,stride1 = 8,1
W_t_num = T if wt > 0 else 1#min(W_t + 2*wt,T)
S = W_t_num*(ws*ws + 2*(ws//2)*ws + (ws//2)**2)
vals = np.zeros((T,H,W,ws,ws))
names = -np.ones((S,T,H,W,3))
@@ -215,18 +220,19 @@ def main():
K = flows_k.shape[-2]
st_offset = 1
print("flows_k.shape: ",flows_k.shape)
nH,nW = (H-1)//stride0+1,(W-1)//stride0+1
for ti in range(T):
for hi in range(H):
for wi in range(W):
for h_ref in range(nH):
for w_ref in range(nW):
for ki in range(K):
fill_names(ti,hi,wi,ki,ws,wt,stride0,stride1,
fill_names(ti,h_ref,w_ref,ki,ws,wt,stride0,stride1,
st_offset,full_ws,names,counts,flows_k)

print(counts[:,0,2,2])
print(counts[:,0,:3,:3].T)
# for i in range(S):
# print(counts[i,0])
print(np.sum(counts>=0),T*H*W*K)
print(np.sum(counts==0),T*H*W*K)
print(np.sum(counts>=0),T*nH*nW*K)
print(np.sum(counts==0),T*nH*nW*K)
if __name__ == "__main__":
main()
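
A minimal sketch of the coarse-grid indexing the dev script now uses (the names mirror the diff above; the example values are mine, not part of the commit):

def ref_grid(H, W, stride0):
    # one query per stride0-spaced pixel, matching nH,nW = (H-1)//stride0+1 above
    nH, nW = (H - 1) // stride0 + 1, (W - 1) // stride0 + 1
    for h_ref in range(nH):
        for w_ref in range(nW):
            hi, wi = h_ref * stride0, w_ref * stride0   # pixel location of the query
            yield h_ref, w_ref, hi, wi

print(len(list(ref_grid(64, 64, 8))))   # 64 queries: an 8x8 grid at pixels 0,8,...,56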
7 changes: 4 additions & 3 deletions lib/csrc/agg/scatter.cpp
@@ -23,7 +23,7 @@ void scatter_tensor_forward_cuda(torch::Tensor out_tensor,
const torch::Tensor in_tensor,
const torch::Tensor labels,
const torch::Tensor flows_k,
int stride0, int stride1);
int stride0, int stride1, int H, int W);

void scatter_tensor_backward_cuda(torch::Tensor in_tensor_grad,
const torch::Tensor out_tensor_grad,
@@ -78,12 +78,13 @@ void scatter_tensor_forward(
const torch::Tensor in_tensor,
const torch::Tensor labels,
const torch::Tensor flows_k,
int stride0, int stride1){
int stride0, int stride1, int H, int W){
CHECK_INPUT(out_tensor);
CHECK_INPUT(in_tensor);
CHECK_INPUT(labels);
CHECK_INPUT(flows_k);
scatter_tensor_forward_cuda(out_tensor,in_tensor,labels,flows_k,stride0,stride1);
scatter_tensor_forward_cuda(out_tensor,in_tensor,labels,flows_k,
stride0,stride1,H,W);
}

void scatter_tensor_backward(
16 changes: 9 additions & 7 deletions lib/csrc/agg/scatter_labels_kernel.cu
@@ -24,8 +24,10 @@ void get_unique_index(int& li, bool& oob,
int ws_j = -1;

// -- check spatial coordinates --
int num_h = (nl_hi - hi)/stride1;
int num_w = (nl_wi - wi)/stride1;
int num_h = abs(nl_hi - hi)/stride1;
num_h = (nl_hi >= hi) ? num_h : -num_h;
int num_w = abs(nl_wi - wi)/stride1;
num_w = (nl_wi >= wi) ? num_w : -num_w;

// -- check oob --
bool oob_i = abs(num_h) > wsHalf;
@@ -78,12 +80,12 @@ __global__ void scatter_labels_kernel(
int nW = flows.size(6);
int K = flows_k.size(5);
int S = names.size(2);
int H = names.size(4);
int W = names.size(5);

// -- derived --
int nHW = nH*nW;
int Q = T*nHW;
int H = nH*stride0;
int W = nW*stride0;

// -- indexing variables --
int ref_patch[3];
@@ -113,8 +115,8 @@
// -- reference index --
get_pixel_loc(ref_patch,qi,stride0,nW,nHW,H,W);
int ti = ref_patch[0];
int h_ref = ref_patch[1];
int w_ref = ref_patch[2];
// int h_ref = ref_patch[1];
// int w_ref = ref_patch[2];
int hi = ref_patch[1]/stride0;
int wi = ref_patch[2]/stride0;

@@ -145,7 +147,7 @@
int dt = static_cast<int>(nl_patch[0]) - ti;
int dto = t_max - ti;
int si = (dt > 0) ? (dt-st_offset) : (dto - dt - st_offset);
int ws_ti = (ref_patch[0]+nl_patch[0]) % T;
int ws_ti = (wt > 0) ? (ref_patch[0]+nl_patch[0]) % T : 0;

// -- offset reference --
// if (si >= 0){
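
The kernel above now splits the spatial offset into a magnitude and a sign before dividing by stride1. A small Python sketch of that arithmetic (my reading of the change, not a statement from the commit): with a plain floor division the two sides of the reference would round differently.

def offset_steps(nl_hi, hi, stride1):
    # magnitude in units of stride1, rounded toward zero, with the sign re-applied
    num_h = abs(nl_hi - hi) // stride1
    return num_h if nl_hi >= hi else -num_h

print(offset_steps(8, 5, 2))   #  1
print(offset_steps(5, 8, 2))   # -1, whereas (5 - 8) // 2 == -2 in Python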
14 changes: 7 additions & 7 deletions lib/csrc/agg/scatter_tensor_kernel.cu
@@ -19,7 +19,7 @@ __global__ void scatter_tensor_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> in_tensor,
const torch::PackedTensorAccessor32<int,4,torch::RestrictPtrTraits> labels,
const torch::PackedTensorAccessor32<int,7,torch::RestrictPtrTraits> flows_k,
int stride0, int stride1){
int stride0, int stride1, int H, int W){

// -- unpack --
int B = flows_k.size(0);
@@ -33,10 +33,10 @@ __global__ void scatter_tensor_forward_kernel(
// -- derived --
int nHW = nH*nW;
int Q = T*nHW;
int H = nH*stride0;
int W = nW*stride0;
int nH1 = H/stride1;
int nW1 = W/stride1;
// int H = nH*stride0;
// int W = nW*stride0;
int nH1 = (H-1)/stride1+1;
int nW1 = (W-1)/stride1+1;

// -- indexing variables --
int ref_patch[3];
@@ -87,7 +87,7 @@ void scatter_tensor_forward_cuda(
const torch::Tensor in_tensor,
const torch::Tensor labels,
const torch::Tensor flows_k,
int stride0, int stride1){
int stride0, int stride1, int H, int W){

// -- sizes --
int B = labels.size(0);
@@ -109,7 +109,7 @@
in_tensor.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
labels.packed_accessor32<int,4,torch::RestrictPtrTraits>(),
flows_k.packed_accessor32<int,7,torch::RestrictPtrTraits>(),
stride0,stride1);
stride0,stride1,H,W);
}));


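
A quick arithmetic check of why H and W are now passed in rather than rebuilt from the query grid (the example sizes are mine): when stride0 does not divide H, nH*stride0 overshoots the true height, and H/stride1 undercounts the stride1 grid.

H, stride0, stride1 = 65, 8, 2
nH = (H - 1) // stride0 + 1                  # 9 query rows
print(nH * stride0)                          # 72, not 65: H is not recoverable from nH*stride0
print(H // stride1, (H - 1) // stride1 + 1)  # 32 vs 33: old vs new row count at stride1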
9 changes: 5 additions & 4 deletions lib/stnls/agg/scatter_labels.py
@@ -16,22 +16,23 @@
# -- cpp cuda kernel --
import stnls_cuda

def run(flows,flows_k,ws,wt,stride0,stride1,full_ws):
def run(flows,flows_k,ws,wt,stride0,stride1,H,W,full_ws):

# -- unpack shapes --
B,HD,T,nH,nW,K,_ = flows_k.shape
# B,HD,T,W_t,2,nH,nW = flows.shape
Q = T*nH*nW
W_t = 2*wt+1
H = nH*stride0
W = nW*stride0
# H = nH*stride0
# W = nW*stride0
wsHalf = (ws-1)//2

# -- number of maximum possible groups a single patch can belong to --
Wt_num = T
Wt_num = T if wt > 0 else 1
Ws_num = ws*ws
if full_ws: Ws_num += 2*ws*wsHalf + wsHalf**2
S = Wt_num*Ws_num
print(S,ws,wt,stride0,stride1,full_ws)

# -- prepare --
labels = -th.ones((B,HD,Q,K),device=flows.device,dtype=th.int)
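
A quick check of the label-count formula above, using the same settings as the updated dev/named_full_ws.py (ws=9, wt=0, full_ws=True):

ws, wt, T, full_ws = 9, 0, 5, True
wsHalf = (ws - 1) // 2                        # 4
Wt_num = T if wt > 0 else 1                   # 1: no temporal search when wt == 0
Ws_num = ws * ws
if full_ws:
    Ws_num += 2 * ws * wsHalf + wsHalf ** 2   # expands ws*ws to (ws + wsHalf)**2
print(Wt_num * Ws_num)                        # S = 169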
19 changes: 11 additions & 8 deletions lib/stnls/agg/scatter_tensor.py
@@ -17,30 +17,32 @@
# -- cpp cuda kernel --
import stnls_cuda

def run(tensor,flows_k,labels,stride0,stride1):
def run(tensor,flows_k,labels,stride0,stride1,H,W):

# -- unpack shapes --
B,HD,T,nH0,nW0,K = tensor.shape[:6]
Q0 = T*nH0*nW0
S = labels.max().int()+1
tensor = tensor.reshape(B,HD,Q0,K,-1)
M = tensor.shape[-1]
nH1 = (nH0*stride0)//stride1
nW1 = (nW0*stride0)//stride1
nH1 = (H-1)//stride1+1
nW1 = (W-1)//stride1+1
Q1 = T*nH1*nW1

# -- change type if needed --
dtype = tensor.dtype
if tensor.dtype == th.int:
if dtype in [th.int32,th.int64]:
tensor = tensor.float()

# -- prepare --
scatter_tensor = -th.inf*th.ones((B,HD,Q1,S,M),device=labels.device,dtype=th.float)
stnls_cuda.scatter_tensor_forward(scatter_tensor,tensor,labels,flows_k,stride0,stride1)
shape = (B,HD,Q1,S,M)
scatter_tensor = -th.inf*th.ones(shape,device=labels.device,dtype=tensor.dtype)
stnls_cuda.scatter_tensor_forward(scatter_tensor,tensor,labels,flows_k,
stride0,stride1,H,W)

# -- adjust output type --
if dtype == th.int:
tensor = tensor.int()
if dtype in [th.int32,th.int64]:
scatter_tensor = scatter_tensor.type(dtype)

# -- squeeze single M --
if M == 1:
@@ -100,5 +102,6 @@ def run_topk(weights,flows_k,K,descending=True):
# -- unpack --
weights = rearrange(weights,'(b hd q) k -> b hd q k',b=B,hd=HD)
flows_topk = rearrange(flows_topk,'(b hd q) k tr -> b hd q k tr',b=B,hd=HD)
flows_topk = flows_topk.type(flows_k.dtype)

return weights,flows_topk
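
A sketch of the dtype round-trip the updated run() performs (simplified; `op` stands in for the stnls_cuda.scatter_tensor_forward call above):

import torch as th

def scatter_like(tensor, op):
    dtype = tensor.dtype
    if dtype in [th.int32, th.int64]:
        tensor = tensor.float()               # the CUDA scatter works on floats
    out = op(tensor)
    if dtype in [th.int32, th.int64]:
        out = out.type(dtype)                 # cast the output back (the old code re-cast the input)
    return out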
6 changes: 3 additions & 3 deletions lib/stnls/search/impl/paired_refine.py
@@ -101,7 +101,7 @@ def forward(frame0, frame1, flow,
# -- topk --
assert self_action in [None,"anchor","anchor_each"]
anchor_self = False if self_action is None else "anchor" in self_action
if topk_mode == "all":
if topk_mode == "all" and (k > 0):
dim = 3
dists=dists.view(B,HD,Q,-1)
inds=inds.view(B,HD,Q,-1,2)
@@ -113,14 +113,14 @@
kselect = kselect.view(B,HD,Q,Ks*wr*wr) if not(kselect is None) else kselect
# print("kselect.shape: ",kselect.shape,order.shape)
kselect = stnls.nn.topk_f.apply_topk(kselect,order,dim)
elif topk_mode == "each":
elif topk_mode == "each" and (k > 0):
dists = rearrange(dists,'... wh ww -> ... (wh ww)')
inds = rearrange(inds,'... wh ww d2or3 -> ... (wh ww) d2or3')
dists,inds = stnls.nn.topk_each(dists,inds,k,descending,anchor_self=anchor_self)
if kselect.ndim > 1 and k > 0:
kselect = rearrange(kselect,'... wh ww -> ... (wh ww)')
kselect = kselect[...,:k] # all same across dim
else:
elif (k > 0):
raise ValueError(f"Unknown topk_mode [{topk_mode}]")
# print("[post]: ",dists.shape,inds.shape)

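
For the "each" top-k path above, a toy version of the flatten-then-select step, using plain torch.sort in place of stnls.nn.topk_each (so no anchoring logic; k <= 0 now simply skips the selection):

import torch as th
from einops import rearrange

def topk_each_sketch(dists, k, descending=False):
    dists = rearrange(dists, '... wh ww -> ... (wh ww)')   # flatten the (wr, wr) window
    if k > 0:
        dists, _ = th.sort(dists, dim=-1, descending=descending)
        dists = dists[..., :k]                              # keep the k best per query
    return dists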
4 changes: 2 additions & 2 deletions lib/stnls/search/paired_refine.py
@@ -209,7 +209,7 @@ def _apply(frame0, frame1, flow,
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def extract_config(cfg,restrict=True):
pairs = {"wr":1,"ws":-1,"ps":3,"k":10,
pairs = {"ws":-1,"wr":1,"k":10,"kr":-1,"ps":1,
"nheads":1,"dist_type":"l2",
"stride0":1, "stride1":1, "dilation":1,
"restricted_radius":False,
@@ -221,7 +221,7 @@

def init(cfg):
cfg = extract_config(cfg,False)
search = PairedRefine(cfg.wr, cfg.ws, cfg.ps, cfg.k, nheads=cfg.nheads,
search = PairedRefine(cfg.ws, cfg.wr, cfg.k, cfg.kr, cfg.ps, nheads=cfg.nheads,
dist_type=cfg.dist_type, stride0=cfg.stride0,
stride1=cfg.stride1, dilation=cfg.dilation,
restricted_radius=cfg.restricted_radius,
2 changes: 1 addition & 1 deletion lib/stnls/search/refinement.py
@@ -197,7 +197,7 @@ def _apply(vid0, vid1, flows,
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

def extract_config(cfg,restrict=True):
pairs = {"ws":-1,"wt":-1,"ps":1,"k":10,"wr":1,"kr":-1,
pairs = {"ws":-1,"wt":-1,"wr":1,"ps":1,"k":10,"kr":-1,
"nheads":1, "stride0":4, "stride1":1, "dilation":1, "pt":1,
"dist_type":"l2", "restricted_radius":False,
"reflect_bounds":True, "full_ws":True,
10 changes: 10 additions & 0 deletions lib/stnls/utils/misc.py
@@ -54,6 +54,16 @@ def get_space_grid(H,W,dtype=th.float,device="cuda"):
# assert three == 3,"Must be three."
# return -flows_k

def reflect_inds(inds,H,W):
def reflect_bounds(flow,i,L):
args0 = th.where(flow[...,i] > (L-1))
args1 = th.where(flow[...,i] < 0)
flow[...,i][args0] = 2*(L-1) - flow[...,i][args0]
flow[...,i][args1] = -flow[...,i][args1]
# -- reflect --
reflect_bounds(inds,1,H)
reflect_bounds(inds,2,W)

def flow2inds(flow,stride0):
device = flow.device
B = flow.shape[0]

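The new reflect_inds helper mirrors out-of-range coordinates back into the frame. A small check of the rule it applies per axis (the length L = 8 is an example of mine):

import torch as th

def reflect(x, L):
    x = x.clone()
    x[x > L - 1] = 2 * (L - 1) - x[x > L - 1]   # past the far edge: fold back
    x[x < 0] = -x[x < 0]                        # past zero: mirror
    return x

print(reflect(th.tensor([-3, 2, 9]), 8))        # tensor([3, 2, 5])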