#!/usr/bin/env python3
from abc import abstractmethod
from collections import OrderedDict
from icecream import ic
import cv2
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch_scatter import scatter_sum
from utils.flow_viz import *
import networks.geom.projective_ops as pops
from networks.modules.corr import CorrBlock, AltCorrBlock
import lietorch
from lietorch import SE3
import droid_backends
import gtsam
from gtsam import (HessianFactor)
from gtsam import Values
from gtsam import (Pose3, Rot3, Point3)
from gtsam import PriorFactorPose3
from gtsam import NonlinearFactorGraph
from gtsam import GaussianFactorGraph
from gtsam.symbol_shorthand import X
# utility functions for scattering ops
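# safe_scatter_add_mat accumulates the per-edge blocks A[:, e] into a dense n*m matrix laid out row-major:
# edge (i, j) goes to flat index i*m + j, and edges with out-of-range indices are dropped first.
# E.g. (a small worked example, not called here): with n = m = 2 and edges ii = [0, 1], jj = [1, 1],
# the two blocks land at flat indices 1 and 3 of the 4-slot output.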
def safe_scatter_add_mat(A, ii, jj, n, m):
v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m)
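# Conversions between lietorch and GTSAM pose representations.
# lietorch.SE3 stores [tx, ty, tz, qx, qy, qz, qw] (scalar-last quaternion), while
# gtsam.Rot3.Quaternion takes (w, x, y, z) and Rot3.quaternion() returns [w, x, y, z],
# hence the index shuffling in the two helpers below.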
def lietorch_pose_to_gtsam(pose : lietorch.SE3):
trans, quat = pose.vec().split([3,4], -1)
trans = trans.cpu().numpy()
quat = quat.cpu().numpy()
return Pose3(Rot3.Quaternion(quat[3], quat[0], quat[1], quat[2]), Point3(trans))
def gtsam_pose_to_torch(pose: gtsam.Pose3, device, dtype):
t = pose.translation()
q = pose.rotation().quaternion()
return torch.tensor([t[0], t[1], t[2], q[1], q[2], q[3], q[0]], device=device, dtype=dtype)
class VisualFrontend(nn.Module):
def __init__(self):
super().__init__()
@abstractmethod
def forward(self, mini_batch):
pass
from networks.modules.extractor import BasicEncoder
from networks.droid_net import UpdateModule
class RaftVisualFrontend(VisualFrontend):
def __init__(self, world_T_body_t0, body_T_cam0, args, device="cpu"):
super().__init__()
self.args = args
self.kf_idx = 0 # Keyframe index
self.kf_idx_to_f_idx = {} # Keyframe index to frame index
self.f_idx_to_kf_idx = {} # Frame index to keyframe index
self.last_kf_idx = 0
self.last_k = None
self.global_ba = False
self.stop = False # stop module
self.compute_covariances = True
self.last_state = gtsam.Values()
self.initial_x0 = None
self.initial_priors = None
self.factors_to_remove = gtsam.KeyVector()
self.buffer = args.buffer
self.stereo = args.stereo
self.device = device
self.is_initialized = False
self.keyframe_warmup = 8
self.max_age = 25
self.max_factors = 48
self.kf_init_count = 8
self.motion_filter_thresh = 2.4 # Mean optical flow [px] required before a new frame is considered, i.e. to decide whether we are moving
self.viz = False # Whether to visualize the results
self.world_T_body_t0 = world_T_body_t0
self.body_t0_T_world = gtsam_pose_to_torch(self.world_T_body_t0.inverse(), self.device, torch.float)
self.body_T_cam0 = body_T_cam0
self.world_T_cam0_t0 = world_T_body_t0 * body_T_cam0
self.cam0_t0_T_world = gtsam_pose_to_torch(
self.world_T_cam0_t0.inverse(), self.device, torch.float)
self.cam0_T_body = gtsam_pose_to_torch(
body_T_cam0.inverse(), self.device, torch.float)
# Frontend params
self.keyframe_thresh = 4.0 # Distance to consider a keyframe, threshold to create a new keyframe [m] # why not 0.4!
self.frontend_thresh = 16.0 # Add edges between frames within this distance
self.frontend_window = 25 # frontend optimization window
self.frontend_radius = 2 # force edges between frames within radius
self.frontend_nms = 1 # non-maximal suppression of edges
self.beta = 0.3 # weight for translation / rotation components of flow # also used in backend
# Backend params
self.backend_thresh = 22.0
self.backend_radius = 2
self.backend_nms = 3
self.iters1 = 4 # number of iterations for first optimization
self.iters2 = 2 # number of iterations for second optimization
# DownSamplingFactor: resolution of the images with respect to the features extracted.
# 8.0 means that the features are at 1/8th of the original resolution.
self.dsf = 8 # perhaps the most important parameter
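# E.g. for a 480x752 EuRoC image with dsf = 8, the feature/context maps below are 60x94.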
# Type of correlation computation to use: "volume" or "alt"
# "volume" precomputes the full correlation volume (faster, but uses a lot of memory);
# "alt" computes correlations on the fly (much less memory, but slower in practice).
self.corr_impl = "volume"
# Build Networks
self.feature_net = BasicEncoder(output_dim=128, norm_fn='instance')
self.context_net = BasicEncoder(output_dim=256, norm_fn='none')
self.update_net = UpdateModule()
# Load network weights
weights = self.load_weights(args.weights)
missing_keys = self.load_state_dict(weights)
self.to(device)
self.eval()
# Uncertainty sigmas, initial sigmas for initialization (but not priors?)
self.translation_sigma = torch.tensor(0.01, device=self.device) # standard deviation of translation [m]
self.rotation_sigma = torch.tensor(0.01, device=self.device) # standard deviation of rotation [rad]
# TODO: given that the values are much larger than 1.0... we should increase this much more...
self.sigma_idepth = torch.tensor(0.1, device=self.device) # standard deviation of the inverse depth [1/m]; we don't know the scale anyway...
self.t_cov = torch.pow(self.translation_sigma, 2) * torch.eye(3, device=self.device)
self.r_cov = torch.pow(self.rotation_sigma, 2) * torch.eye(3, device=self.device)
self.idepth_prior_cov = torch.pow(self.sigma_idepth, 2)
self.g_prior_cov = torch.block_diag(self.r_cov, self.t_cov) # GTSAM convention, rotation first, then translation
def __del__(self):
print("Calling frontend dtor...")
torch.cuda.empty_cache()
def stop_condition(self):
return self.stop
# Pre-allocate all the memory in the GPU
def initialize_buffers(self, image_size):
self.img_height = h = image_size[0]
self.img_width = w = image_size[1]
ic(self.dsf)
ic(h)
ic(w)
ic(h//self.dsf)
ic(w//self.dsf)
self.coords0 = pops.coords_grid(h//self.dsf, w//self.dsf, device=self.device)
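# coords0 is the canonical pixel-coordinate grid at feature (1/dsf) resolution;
# all flows below are expressed relative to it (flow = coords1 - coords0).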
self.ht, self.wd = self.coords0.shape[:2]
### Input attributes ###
self.cam0_timestamps = torch.zeros(self.buffer, dtype=torch.float, device=self.device).share_memory_()
# TODO this should be in the euroc parser, so that we don't allocate memory without bounds
self.cam0_images = torch.zeros(self.buffer, 3, h, w, dtype=torch.uint8, device=self.device).share_memory_() # This is a lot of memory
self.cam0_intrinsics = torch.zeros(self.buffer, 4, dtype=torch.float, device=self.device).share_memory_()
self.gt_poses = torch.zeros(self.buffer, 4, 4, dtype=torch.float, device=self.device).share_memory_()
self.gt_depths = torch.zeros(self.buffer, 1, h, w, dtype=torch.float, device=self.device).share_memory_()
### State attributes ###
self.cam0_T_world = torch.zeros(self.buffer, 7, dtype=torch.float, device=self.device).share_memory_()
self.world_T_body = torch.zeros(self.buffer, 7, dtype=torch.float, device=self.device).share_memory_()
self.world_T_body_cov = torch.zeros(self.buffer, 6, 6, dtype=torch.float, device=self.device).share_memory_()
self.cam0_idepths = torch.ones(self.buffer, h//self.dsf, w//self.dsf, dtype=torch.float, device=self.device).share_memory_()
self.cam0_idepths_cov = torch.ones(self.buffer, h//self.dsf, w//self.dsf, dtype=torch.float, device=self.device).share_memory_()
self.cam0_depths_cov = torch.ones(self.buffer, h//self.dsf, w//self.dsf, dtype=torch.float, device=self.device).share_memory_()
self.cam0_idepths_sensed = torch.zeros(self.buffer, h//self.dsf, w//self.dsf, dtype=torch.float, device=self.device).share_memory_()
self.cam0_idepths_up = torch.zeros(self.buffer, h, w, dtype=torch.float, device=self.device).share_memory_() # This is a lot of memory
self.cam0_depths_cov_up = torch.ones(self.buffer, h, w, dtype=torch.float, device=self.device).share_memory_() # This is a lot of memory
# INITIALIZE state:
# - poses all to initial state transformation
# - velocities all to 0 except first to initial state
# - biases all to 0 except first to initial state
# TODO: why not shared memory?
self.cam0_T_world[:] = self.cam0_t0_T_world
self.world_T_body[:] = gtsam_pose_to_torch(self.world_T_body_t0, device=self.device, dtype=torch.float)
self.world_T_body_cov[:] = self.g_prior_cov * torch.eye(6, device=self.device) #* torch.as_tensor([0.001, 0.001, 0.001, 0.001, 0.001, 0.001], device=self.device)[None]
self.cam0_idepths_cov *= self.idepth_prior_cov
# For multi-view, we could set this to >2, but we need to know overlapping FOVs
# we could do that automatically by looking at the mean flow between frames
cameras = 2 if self.stereo else 1
### Feature attributes ### Every keyframe has a feature/context/gru_input
self.features_imgs = torch.zeros(self.buffer, cameras, 128, h//self.dsf, w//self.dsf, dtype=torch.half, device=self.device)#.share_memory_()
self.contexts_imgs = torch.zeros(self.buffer, cameras, 128, h//self.dsf, w//self.dsf, dtype=torch.half, device=self.device)#.share_memory_()
self.cst_contexts_imgs = torch.zeros(self.buffer, cameras, 128, h//self.dsf, w//self.dsf, dtype=torch.half, device=self.device)#.share_memory_()
### Correlations, Flows, and Hidden States ### Every pair of co-visible keyframes has a correlation volume, flow, and hidden state
# These are created on-the-fly so we can't really pre-allocate memory
self.correlation_volumes = None
self.gru_hidden_states = None # initialized as context, but evolves as hidden state
self.gru_contexts_input = None # initialized as context, and remains as such
self.gru_estimated_flow = torch.zeros([1, 0, h//self.dsf, w//self.dsf, 2], device=self.device, dtype=torch.float)
self.gru_estimated_flow_weight = torch.zeros([1, 0, h//self.dsf, w//self.dsf, 2], device=self.device, dtype=torch.float)
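# Note: despite the name, gru_estimated_flow stores absolute target pixel coordinates in frame jj
# (i.e. coords0 + flow), not flow deltas; see how it is set to coords1 + flow_delta in update().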
self.damping = 1e-6 * torch.ones_like(self.cam0_idepths) # per-pixel damping added to the depth block in the BA solve (LM-style)
### Co-visibility Graph ###
self.ii = torch.as_tensor([], dtype=torch.long, device=self.device)
self.jj = torch.as_tensor([], dtype=torch.long, device=self.device)
self.age = torch.as_tensor([], dtype=torch.long, device=self.device)
# inactive factors
self.ii_inactive = torch.as_tensor([], dtype=torch.long, device=self.device)
self.jj_inactive = torch.as_tensor([], dtype=torch.long, device=self.device)
self.ii_bad = torch.as_tensor([], dtype=torch.long, device=self.device)
self.jj_bad = torch.as_tensor([], dtype=torch.long, device=self.device)
self.gru_estimated_flow_inactive = torch.zeros([1, 0, h//self.dsf, w//self.dsf, 2], device=self.device, dtype=torch.float)
self.gru_estimated_flow_weight_inactive = torch.zeros([1, 0, h//self.dsf, w//self.dsf, 2], device=self.device, dtype=torch.float)
# For visualization, True: needs viz update, False: not changed
self.viz_idx = torch.zeros(self.buffer, device=self.device, dtype=torch.bool)
#@abstractclassmethod
def forward(self, batch):
# The output of RaftVisualFrontend is not dense optical flow
# but rather a bunch of pose-to-pose factors resulting from the reduced camera matrix.
print("RaftVisualFrontend.forward")
k = batch["k"][0]
# The output
x0 = Values()# None #
factors = NonlinearFactorGraph()# None #
viz_out = None
imgs_k = torch.as_tensor(batch["images"], device=self.device)[None].permute(0, 1, 4, 2, 3)#.shared_memory()
imgs_norm_k = self._normalize_imgs(imgs_k)
if self.viz:
for i, img in enumerate(imgs_k[0]):
cv2.imshow(f'Img{i} input', img.permute(1,2,0).cpu().numpy())
for i, img in enumerate(imgs_norm_k[0]):
cv2.imshow(f'Img{i} normalized', img.permute(1,2,0).cpu().numpy())
if self.last_k is None:
ic(k)
assert k == 0
assert self.kf_idx == 0
assert self.last_kf_idx == 0
# Initialize network buffers
self.initialize_buffers(imgs_k.shape[-2:]) # last two dims are h,w
self.gt_poses[self.kf_idx] = torch.tensor(batch["poses"][0], device=self.device)
if batch["depths"][0] is not None:
self.gt_depths[self.kf_idx] = torch.tensor(batch["depths"][0], device=self.device).permute(2,0,1)
self.cam0_timestamps[self.kf_idx] = torch.tensor(batch["t_cams"][0], device=self.device)
self.cam0_images[self.kf_idx] = torch.tensor(batch["images"][0], device=self.device)[..., :3].permute(2,0,1)
self.cam0_intrinsics[self.kf_idx] = (1.0 / self.dsf) * torch.tensor(batch["calibs"][0].camera_model.numpy(), device=self.device)
# Initialize the state
# Compute its dense features for next iteration
self.features_imgs[self.kf_idx] = self.__feature_encoder(imgs_norm_k)
# Compute its context features for next iteration
self.contexts_imgs[self.kf_idx], self.cst_contexts_imgs[self.kf_idx] = self.__context_encoder(imgs_norm_k)
# Store things for next iteration
self.last_k = k
self.last_kf_idx = self.kf_idx
self.kf_idx_to_f_idx[self.kf_idx] = k
self.f_idx_to_kf_idx[k] = self.kf_idx
viz_out = self.get_viz_out(batch)
self.kf_idx += 1
return x0, factors, viz_out
assert k > 0
assert self.kf_idx < self.buffer
# Add frame as keyframe if we have enough motion, otherwise discard:
current_imgs_features = self.__feature_encoder(imgs_norm_k)
if not self.has_enough_motion(current_imgs_features):
if batch["is_last_frame"]:
self.kf_idx -= 1 # Because we incremented it on the previous iteration, but are not taking this frame as a keyframe...
print("Last frame reached, and no new motion: starting GLOBAL BA")
self.terminate()
# Send the whole viz_out to reflect that BA has changed all poses/depths
viz_out = self.get_viz_out(batch)
return x0, factors, viz_out
# By returning, we do not increment self.kf_idx
return x0, factors, viz_out
# Ok, we got enough motion, consider this frame as a keyframe
# Compute dense optical flow
self.gt_poses[self.kf_idx] = torch.tensor(batch["poses"][0], device=self.device)
if batch["depths"][0] is not None:
self.gt_depths[self.kf_idx] = torch.tensor(batch["depths"][0], device=self.device).permute(2,0,1)
self.cam0_timestamps[self.kf_idx] = torch.tensor(batch["t_cams"][0], device=self.device)
self.cam0_images[self.kf_idx] = torch.tensor(batch["images"][0], device=self.device)[..., :3].permute(2,0,1)
self.cam0_intrinsics[self.kf_idx] = (1.0 / self.dsf) * torch.tensor(batch["calibs"][0].camera_model.numpy(), device=self.device)
self.features_imgs[self.kf_idx] = current_imgs_features
self.contexts_imgs[self.kf_idx], self.cst_contexts_imgs[self.kf_idx] = self.__context_encoder(imgs_norm_k)
self.kf_idx_to_f_idx[self.kf_idx] = k
self.f_idx_to_kf_idx[k] = self.kf_idx
# Build the flow graph: ii -> jj edges
# TODO: for now just do a chain
# Just adds the `r' sequential frames to the graph
# do initialization
if not self.is_initialized:
if self.kf_idx >= self.keyframe_warmup:
self.__initialize()
else:
# We don't return here so that we increment kf_idx
# ic("Warming up: pre-processing frame with enough motion")
pass
# do update
else:
if not self.__update():
# We did not accept this keyframe, reinit its properties (really needed? they will be overwritten no?)
# Remove as well factors connected to it used for estimating its distance...
self.rm_keyframe(self.kf_idx - 1) # TODO: the -1 here changed the whole behavior, check if it is correct
# Decrease kf_idx since we are literally removing the keyframe...
# But this means that we need to keep track the difference btw the kf_idx in the backend
# and the keyframe_idx in the frontend.
#self.video.counter.value -= 1
# self.kf_idx -= 1 # so that on the next pass we use the previous keyframe
# By returning, we do not increment self.kf_idx
return x0, factors, viz_out
#x0.insert(X(k), pose_to_gtsam(last_pose))
self.last_k = k
self.last_kf_idx = self.kf_idx
self.kf_idx_to_f_idx[self.kf_idx] = k # not really necessary I think
self.f_idx_to_kf_idx[k] = self.kf_idx
viz_out = self.get_viz_out(batch) # build viz_out after updating self.kf_idx_to_f_idx
if self.viz:
cv2.waitKey(1)
if self.kf_idx + 1 >= self.buffer or batch["is_last_frame"]:
print("Buffer full or last frame reached: starting GLOBAL BA")
self.terminate()
viz_out = self.get_viz_out(batch)
return x0, factors, viz_out
self.kf_idx += 1
return x0, factors, viz_out
# If kf0 is None, it is set to max(0, min(ii)) below (DROID uses max(1, min(ii)+1), keeping the first pose fixed).
# If kf1 is None, it is set to max(ii, jj) + 1, i.e. one past the last keyframe in the graph (an exclusive end index).
@torch.cuda.amp.autocast(enabled=True)
def update(self, kf0=None, kf1=None, itrs=2, use_inactive=False, EP=1e-7, motion_only=False):
""" run update operator on factor graph """
#ic("Memory usage before update: {} Mb".format(torch.cuda.memory_allocated()/1024**2))
# motion features
with torch.cuda.amp.autocast(enabled=False): # try mixed precision?
# Coords1 shape is: (batch_size, num_edges, ht, wd, 2)
coords1, mask, (Ji, Jj, Jz) = self.reproject(self.ii, self.jj, cam_T_body=self.cam0_T_body, jacobian=True) # this is not using the cuda kernels... # mask is not used...
# "coords1 - coords0": coords1 = coords0 + flow, so this is the current
# flow induced by the estimated pose/depth.
# "target - coords1": residual from current
# - `measured` flow (target) by GRU: target = coords1 + flow_delta (from GRU), and
# - `estimated' flow (coords1): coords1 = coords0 + reproject()
motion = torch.cat([coords1 - self.coords0, self.gru_estimated_flow - coords1], dim=-1)
motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0)
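# motion has 4 channels per pixel: the flow currently induced by the pose/depth estimates
# (coords1 - coords0) and the residual to the GRU target flow (target - coords1), clamped to +/- 64 px.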
# correlation features
corr = self.correlation_volumes(coords1)
# We then pool the hidden state over all features which share the same source view i and predict a
# pixel-wise damping factor λ. We use the softplus operator to ensure that the damping term is positive.
# Additionally, we use the pooled features to predict a 8x8 mask which can be used to upsample the
# inverse depth estimate
self.gru_hidden_states, flow_delta, gru_estimated_flow_weight, damping, upmask = \
self.update_net(self.gru_hidden_states, self.gru_contexts_input,
corr, flow=motion, ii=self.ii, jj=self.jj)
if kf0 is None:
# In droid: kf0 = max(1, self.ii.min().item()+1) # It is max(1, min(ii)) because the first pose is fixed.
kf0 = max(0, self.ii.min().item())
else:
ic(kf0)
raise NotImplementedError("Passing an explicit kf0 to update() is not supported.")
with torch.cuda.amp.autocast(enabled=False):
# flow_delta shape is: (batch_size, num_edges, ht, wd, 2)
self.gru_estimated_flow = coords1 + flow_delta.to(dtype=torch.float)
self.gru_estimated_flow_weight = gru_estimated_flow_weight.to(dtype=torch.float)
self.damping[torch.unique(self.ii)] = damping # TODO What is this damping? See also `damping` below
if use_inactive: # It is always set to True...
# TODO What is this doing? I think this is somehow setting the priors!! Or
# the marginalization priors...
# The thing is that, there are two considerations:
# - depth_maps to optimize over (K = len(torch.unique(ii)))
# - keyframe poses to optimize over (P = kf1-kf0)
# where K > P, so that there are some poses we do not optimize over (but we do optimize their depth-maps)
mask = (self.ii_inactive >= kf0 - 3) & (self.jj_inactive >= kf0 - 3)
ii = torch.cat([self.ii_inactive[mask], self.ii], 0)
jj = torch.cat([self.jj_inactive[mask], self.jj], 0)
gru_estimated_flow = torch.cat([self.gru_estimated_flow_inactive[:,mask], self.gru_estimated_flow], 1)
gru_estimated_flow_weight = torch.cat([self.gru_estimated_flow_weight_inactive[:,mask], self.gru_estimated_flow_weight], 1)
else:
ii, jj, gru_estimated_flow, gru_estimated_flow_weight = self.ii, self.jj, self.gru_estimated_flow, self.gru_estimated_flow_weight
damping = .2 * self.damping[torch.unique(ii)].contiguous() + EP # TODO What is this damping?
# gru_estimated_flow(_weight) shape after this line is: (num_edges, 2, ht, wd)
gru_estimated_flow = gru_estimated_flow.view(-1, self.ht, self.wd, 2).permute(0,3,1,2).contiguous()
gru_estimated_flow_weight = gru_estimated_flow_weight.view(-1, self.ht, self.wd, 2).permute(0,3,1,2).contiguous()
# TODO We should output at this point the GRU estimated flow and weights.
# Or rather the factors?
# Dense bundle adjustment
ic("BA!")
x0, rcm_factor = self.ba(gru_estimated_flow, gru_estimated_flow_weight, damping,
ii, jj, kf0, kf1, itrs=itrs, lm=1e-4, ep=0.1,
motion_only=motion_only, compute_covariances=self.compute_covariances)
# Stores depths_up, depths_cov_up
kx = torch.unique(self.ii)
self.cam0_idepths_up[kx] = cvx_upsample(self.cam0_idepths[kx].unsqueeze(-1), upmask).squeeze()
self.cam0_depths_cov_up[kx] = cvx_upsample(self.cam0_depths_cov[kx].unsqueeze(-1), upmask, pow=1.0).squeeze()
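# cvx_upsample lifts the 1/dsf-resolution maps to full resolution using the GRU-predicted upmask
# (presumably a per-pixel convex combination of the coarse 3x3 neighborhood, as in RAFT's upsampling).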
if self.viz:
viz_idepth(self.cam0_idepths[kx], upmask)
viz_idepth_sigma(self.cam0_idepths_cov[kx], upmask, fix_range=True, bg_img=self.cam0_images[kx])
viz_depth_sigma(self.cam0_depths_cov_up[kx].unsqueeze(-1).sqrt(), fix_range=True, bg_img=self.cam0_images[kx], sigma_thresh=20.0)
viz_flow("gru_flow", gru_estimated_flow[-1] - self.coords0.permute(2,0,1))
reprojection_flow = (coords1 - self.coords0).squeeze().permute(0,3,1,2)
viz_flow("reprojection_flow", reprojection_flow[-1])
# Viz weight
if self.viz:
# Visualize input image as well
#for k, i in enumerate(self.ii):
# self.viz_weight(gru_estimated_flow_weight[k], self.cam0_images[ii[k]]) #self.ii[-1].item(), self.jj[-1].item())
self.viz_weight(gru_estimated_flow_weight[-1], self.cam0_images[ii[-1]]) #self.ii[-1].item(), self.jj[-1].item())
# Update visualization
kf1 = max(ii.max().item(), jj.max().item())
assert kf1 == self.kf_idx
self.viz_idx[kf0:self.kf_idx+1] = True
self.age += 1
return x0, rcm_factor
@torch.cuda.amp.autocast(enabled=False)
def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, steps=8):
""" run update operator on factor graph - reduced memory implementation """
# alternate corr implementation
kfs, cameras, ch, ht, wd = self.features_imgs.shape
corr_op = AltCorrBlock(self.features_imgs.view(1, kfs*cameras, ch, ht, wd))
for step in range(steps):
print(f"Global BA Iteration #{step}/{steps}")
with torch.cuda.amp.autocast(enabled=False):
coords1, mask, _ = self.reproject(self.ii, self.jj)
motion = torch.cat([coords1 - self.coords0, self.gru_estimated_flow - coords1], dim=-1)
motion = motion.permute(0,1,4,2,3).clamp(-64.0, 64.0)
# CONVGRU RUNS
# Optimize the flow as much as possible,
s = 8
for i in range(0, self.jj.max() + 1, s): # process the edges in chunks of s source keyframes to bound peak memory
print(f"ConvGRU Iteration #{i // s}/{(int(self.jj.max()) + 1) // s}")
v = (self.ii >= i) & (self.ii < i + s) # kind-of like a sliding optimization window
iis = self.ii[v]
jjs = self.jj[v]
corr = corr_op(coords1[:,v], cameras * iis, cameras * jjs + (iis == jjs).long())
with torch.cuda.amp.autocast(enabled=True):
# TODO: somehow the damping and upmask have weird shapes... what is going on?
gru_hidden_states, flow_delta, gru_estimated_flow_weight, damping, upmask = \
self.update_net(self.gru_hidden_states[:, v], self.gru_contexts_input[:, iis], corr, motion[:, v], iis, jjs)
kx = torch.unique(iis)
all_kf_ids = torch.unique(torch.cat([iis, jjs], 0))
self.gru_hidden_states[:,v] = gru_hidden_states
self.gru_estimated_flow[:,v] = coords1[:,v] + flow_delta.float()
self.gru_estimated_flow_weight[:,v] = gru_estimated_flow_weight.float()
self.damping[all_kf_ids] = damping # TODO What is this damping? See also `damping` below
# Stores depths_up, depths_cov_up
self.cam0_idepths_up[all_kf_ids] = cvx_upsample(self.cam0_idepths[all_kf_ids].unsqueeze(-1), upmask).squeeze()
self.cam0_depths_cov_up[all_kf_ids] = cvx_upsample(self.cam0_depths_cov[all_kf_ids].unsqueeze(-1), upmask, pow=1.0).squeeze()
#ii = torch.cat([self.ii_inactive[mask], self.ii], 0)
damping = .2 * self.damping[torch.unique(self.ii)].contiguous() + EP
gru_estimated_flow = self.gru_estimated_flow.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous()
gru_estimated_flow_weight = self.gru_estimated_flow_weight.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous()
# dense bundle adjustment
ic("Global BA!")
# TODO: do not compute cov for global BA until we fix the memory issue when building Eiz
x0, rcm_factor = self.ba(gru_estimated_flow, gru_estimated_flow_weight, damping, self.ii, self.jj,
kf0=0, kf1=None, itrs=itrs, lm=1e-5, ep=1e-2,
motion_only=False, compute_covariances=False)
@torch.cuda.amp.autocast(enabled=True)
def rm_keyframe(self, kf_idx):
""" drop nodes from factor graph """
# TODO: how does this work if we rm a keyframe that is not the one before the last one??
# As of now, kf_idx is the last one, and so kf_idx+1 is all 0s.
self.gt_poses[kf_idx] = self.gt_poses[kf_idx+1]
self.gt_depths[kf_idx] = self.gt_depths[kf_idx+1]
self.cam0_images[kf_idx] = self.cam0_images[kf_idx+1]
self.cam0_timestamps[kf_idx] = self.cam0_timestamps[kf_idx+1]
self.cam0_T_world[kf_idx] = self.cam0_T_world[kf_idx+1]
self.world_T_body[kf_idx] = self.world_T_body[kf_idx+1]
self.world_T_body_cov[kf_idx] = self.world_T_body_cov[kf_idx+1]
self.cam0_idepths[kf_idx] = self.cam0_idepths[kf_idx+1]
self.cam0_idepths_cov[kf_idx] = self.cam0_idepths_cov[kf_idx+1]
self.cam0_depths_cov[kf_idx] = self.cam0_depths_cov[kf_idx+1]
self.cam0_idepths_sensed[kf_idx] = self.cam0_idepths_sensed[kf_idx+1]
self.cam0_intrinsics[kf_idx] = self.cam0_intrinsics[kf_idx+1]
self.features_imgs[kf_idx] = self.features_imgs[kf_idx+1]
self.contexts_imgs[kf_idx] = self.contexts_imgs[kf_idx+1]
self.cst_contexts_imgs[kf_idx] = self.cst_contexts_imgs[kf_idx+1]
# Remove all inactive edges that are connected to the keyframe
mask = (self.ii_inactive == kf_idx) | (self.jj_inactive == kf_idx)
# Reindex the inactive edges that we are going to keep
self.ii_inactive[self.ii_inactive >= kf_idx] -= 1
self.jj_inactive[self.jj_inactive >= kf_idx] -= 1
# Remove the inactive edges concerning this keyframe
if torch.any(mask):
self.ii_inactive = self.ii_inactive[~mask]
self.jj_inactive = self.jj_inactive[~mask]
self.gru_estimated_flow_inactive = self.gru_estimated_flow_inactive[:,~mask]
self.gru_estimated_flow_weight_inactive = self.gru_estimated_flow_weight_inactive[:,~mask]
# Remove all edges that are connected to the keyframe
mask = (self.ii == kf_idx) | (self.jj == kf_idx)
# Reindex the edges that we are going to keep
self.ii[self.ii >= kf_idx] -= 1
self.jj[self.jj >= kf_idx] -= 1
# Remove the data concerning this keyframe (correlation volumes, etc.)
self.rm_factors(mask, store=False)
def __update(self):
""" add edges, perform update """
# self.count += 1 # TODO
#self.kf_idx += 1 # TODO I think this is our kf_idx
if self.correlation_volumes is not None:
# Really only drops edges...
#ic("Removing factors, and storing.")
self.rm_factors(self.age > self.max_age, store=True)
# TODO: unclear how this works
# t = self.kf_idx
# ix = torch.arange(kf0, t)
# jx = torch.arange(kf1, t)
#ic("Adding proximity factors")
self.add_proximity_factors(kf0=self.kf_idx - 4,
kf1=max(self.kf_idx + 1 - self.frontend_window, 0),
rad=self.frontend_radius, nms=self.frontend_nms,
thresh=self.frontend_thresh, beta=self.beta, remove=True)
# Initialize current cam0_depths with the sensed depths if valid,
# otherwise with the previous mean depth (during init it is the mean amongst last 4 frames)
# (perhaps try with the previous raw depth)
self.cam0_idepths[self.kf_idx] = torch.where(self.cam0_idepths_sensed[self.kf_idx] > 0,
self.cam0_idepths_sensed[self.kf_idx],
self.cam0_idepths[self.kf_idx])
# TODO: should we lower the sigmas of idepth here since it is a measurement
#ic("First update")
for itr in range(self.iters1):
x0, rcm_factor= self.update(kf0=None, kf1=None, use_inactive=True)
# Get distance between the previous keyframe (kf_idx-2) and the current frame (kf_idx-1)
d = self.distance([self.kf_idx-2], [self.kf_idx-1], beta=self.beta, bidirectional=True)
#ic(d.item())
if d.item() < self.keyframe_thresh:
#ic("Not a keyframe")
return False
else:
#ic("Found a keyframe")
#ic("Second update")
for itr in range(self.iters2):
x0, rcm_factor = self.update(None, None, use_inactive=True)
# TODO: I believe this should be inside this conditional rather than outside,
# Because in the previous case we decided not to accept the keyframe...
# set pose for next iteration
next_kf = self.kf_idx + 1
if next_kf < self.buffer: # if we have not reached the end of the buffer (aka end of sequence)
self.cam0_T_world[next_kf] = self.cam0_T_world[self.kf_idx]
self.world_T_body[next_kf] = self.world_T_body[self.kf_idx]
self.world_T_body_cov[next_kf] = self.world_T_body_cov[self.kf_idx]
# Why not just the previous depths as in init?
#self.cam0_idepths[next_kf] = self.cam0_idepths[self.kf_idx]
self.cam0_idepths[next_kf] = self.cam0_idepths[self.kf_idx].mean()
self.cam0_idepths_cov[next_kf] = self.cam0_idepths_cov[self.kf_idx]
self.cam0_depths_cov[next_kf] = self.cam0_depths_cov[self.kf_idx]
#self.viz_idx[next_kf] = True
return True
def __initialize(self):
""" initialize the SLAM system """
assert(self.kf_idx > 4)
assert(self.kf_idx >= self.keyframe_warmup)
# Just adds the `radius' sequential frames to the graph
self.add_neighborhood_factors(kf0=0, kf1=self.kf_idx, radius=3)
for _ in range(8):
x0, rcm_factor = self.update(kf0=None, kf1=None, use_inactive=True)
# Adds proximity factors between frames (see add_proximity_factors);
# if kf0 and kf1 are 0, it tries to add proximity factors between all keyframes.
# t = self.kf_idx
# ix = torch.arange(kf0, t)
# jx = torch.arange(kf1, t)
#ic("Add proximity factors")
self.add_proximity_factors(kf0=0, kf1=0, rad=2, nms=2,
thresh=self.frontend_thresh, remove=False)
for _ in range(8):
x0, rcm_factor = self.update(kf0=None, kf1=None, use_inactive=True)
# TODO: next kf_idx shouldn't be kf_idx+1?
# Set initial pose/depth for next iteration
self.cam0_T_world[self.kf_idx + 1] = self.cam0_T_world[self.kf_idx].clone()
self.world_T_body[self.kf_idx + 1] = self.world_T_body[self.kf_idx].clone()
self.world_T_body_cov[self.kf_idx + 1] = self.world_T_body_cov[self.kf_idx].clone()
# TODO: Next depth is just the mean of the previous 4 depths?
# We just retrieve our global implicit map here....
self.cam0_idepths[self.kf_idx + 1] = self.cam0_idepths[self.kf_idx - 3:self.kf_idx+1].mean()
# TODO: this is questionable: we reuse the previous sigmas and, worse, the initial sigma, which is just 1.0
# because we don't have a better initialization... yet looking at the actual values, a sigma of 1.0 is far too confident...
self.cam0_idepths_cov[self.kf_idx + 1] = self.cam0_idepths_cov[self.kf_idx-3:self.kf_idx+1].mean()
self.cam0_depths_cov[self.kf_idx + 1] = self.cam0_depths_cov[self.kf_idx-3:self.kf_idx+1].mean()
# initialization complete
self.is_initialized = True
# Update visualization
self.viz_idx[:self.kf_idx+1] = True
# TODO: what is this 4 here? What if warmup is less than 4...
# Remove edges that point to the first keyframes...
#ic("Remove factors after init, and storing")
self.rm_factors(self.ii < (self.keyframe_warmup - 4), store=True)
return x0, rcm_factor
def add_neighborhood_factors(self, kf0, kf1, radius=3):
""" add edges between neighboring frames within radius r """
# Build dense adjacency matrix; +1 because arange stops before (does not include) the last index.
ii, jj = torch.meshgrid(torch.arange(kf0, kf1+1), torch.arange(kf0, kf1+1))
ii = ii.reshape(-1).to(dtype=torch.long, device=self.device)
jj = jj.reshape(-1).to(dtype=torch.long, device=self.device)
c = 1 if self.stereo else 0
# Remove from the dense adjacency matrix those pairs that are not within `radius' frames of each other.
# This keeps the `radius' sub- and super-diagonals around the main diagonal.
distances = torch.abs(ii - jj)
keep_radius = distances <= radius
keep_stereo = distances > c # TODO: not sure why this is like this...
keep = keep_stereo & keep_radius
# Computes the correlation between these frames
self.add_factors(ii[keep], jj[keep])
# Adds edges based on the pairwise frame-distance matrix: sequential edges within `rad' frames are always added,
# then remaining candidate pairs are added greedily in order of increasing distance (bidirectionally, up to max_factors),
# suppressing candidates that fall within an `nms' neighborhood of an already-selected or existing edge.
def add_proximity_factors(self, kf0=0, kf1=0, rad=2, nms=2, beta=0.25, thresh=16.0, remove=False):
""" add edges to the factor graph based on distance """
t = self.kf_idx + 1
ix = torch.arange(kf0, t)
jx = torch.arange(kf1, t)
ii, jj = torch.meshgrid(ix, jx)
ii = ii.reshape(-1)
jj = jj.reshape(-1)
d = self.distance(ii, jj, beta=beta)
d[(ii - rad) < jj] = np.inf # Set closer than rad frames distance to infinity
d[d > 100] = np.inf # Set distances greater than 100 to infinity
ii1 = torch.cat([self.ii, self.ii_bad, self.ii_inactive], 0)
jj1 = torch.cat([self.jj, self.jj_bad, self.jj_inactive], 0)
for i, j in zip(ii1.cpu().numpy(), jj1.cpu().numpy()):
for di in range(-nms, nms+1):
for dj in range(-nms, nms+1):
if abs(di) + abs(dj) <= max(min(abs(i-j)-2, nms), 0):
i1 = i + di
j1 = j + dj
if (kf0 <= i1 < t) and (kf1 <= j1 < t):
d[(i1-kf0)*(t-kf1) + (j1-kf1)] = np.inf
es = []
for i in range(kf0, t):
if self.stereo:
es.append((i, i))
d[(i-kf0)*(t-kf1) + (i-kf1)] = np.inf
for j in range(max(i-rad-1,0), i):
es.append((i,j))
es.append((j,i))
d[(i-kf0)*(t-kf1) + (j-kf1)] = np.inf
ix = torch.argsort(d)
for k in ix:
if d[k].item() > thresh:
continue
if len(es) > self.max_factors:
break
i = ii[k]
j = jj[k]
# bidirectional
es.append((i, j))
es.append((j, i))
for di in range(-nms, nms+1):
for dj in range(-nms, nms+1):
if abs(di) + abs(dj) <= max(min(abs(i-j)-2, nms), 0):
i1 = i + di
j1 = j + dj
if (kf0 <= i1 < t) and (kf1 <= j1 < t):
d[(i1-kf0)*(t-kf1) + (j1-kf1)] = np.inf
ii, jj = torch.as_tensor(es, device=self.device).unbind(dim=-1)
self.add_factors(ii, jj, remove)
def distance(self, ii=None, jj=None, beta=0.3, bidirectional=True):
""" frame distance metric """
return_distance_matrix = False
if ii is None:
return_distance_matrix = True
N = self.kf_idx
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
ii, jj = self.format_indicies(ii, jj, device=self.device)
if bidirectional:
poses = self.cam0_T_world[:self.kf_idx+1].clone() # TODO: why clone? why +1?
d1 = droid_backends.frame_distance(poses, self.cam0_idepths, self.cam0_intrinsics[0], ii, jj, beta)
d2 = droid_backends.frame_distance(poses, self.cam0_idepths, self.cam0_intrinsics[0], jj, ii, beta)
d = .5 * (d1 + d2)
else:
d = droid_backends.frame_distance(self.cam0_T_world, self.cam0_idepths, self.cam0_intrinsics[0], ii, jj, beta)
if return_distance_matrix:
return d.reshape(N, N)
return d
# Computes correlations for each edge in (ii, jj)
# Inits gru_hidden_states for each edge from self.contexts_imgs
# Computes the gru_context_input from the cst_contexts_imgs
# Inits gru_estimated_flow to the reprojected flow
# Inits gru_estimated_flow_weight to 0.0
@torch.cuda.amp.autocast(enabled=True)
def add_factors(self, ii, jj, remove=False):
""" add edges to factor graph """
if not isinstance(ii, torch.Tensor):
ii = torch.as_tensor(ii, dtype=torch.long, device=self.device)
if not isinstance(jj, torch.Tensor):
jj = torch.as_tensor(jj, dtype=torch.long, device=self.device)
# remove duplicate edges, duplication happens because neighborhood and proximity factors
# may overlap.
ii, jj = self.__filter_repeated_edges(ii, jj)
if ii.shape[0] == 0:
return
# place limit on number of factors
old_factors_count = self.ii.shape[0]
new_factors_count = ii.shape[0]
if self.max_factors > 0 and \
old_factors_count + new_factors_count > self.max_factors \
and self.correlation_volumes is not None and remove:
ix = torch.arange(len(self.age))[torch.argsort(self.age).cpu()]
#ic("Remove old factors, and storing.")
self.rm_factors(ix >= (self.max_factors - new_factors_count), store=True)
### Add new factors
self.ii = torch.cat([self.ii, ii], 0)
self.jj = torch.cat([self.jj, jj], 0)
self.age = torch.cat([self.age, torch.zeros_like(ii)], 0)
### Add correlation volumes for new edges, if we do not use alt implementation
# that computes correlations on the fly... (a bit slower for some reason, but saves
# tons of memory, since most of the corr volume will be unexplored I think)
if self.corr_impl == "volume":
is_stereo = (ii == jj).long() # TODO: does this still hold for multi-camera?
feature_img_ii = self.features_imgs[None, ii, 0]
feature_img_jj = self.features_imgs[None, jj, is_stereo]
corr = CorrBlock(feature_img_ii, feature_img_jj)
self.correlation_volumes = self.correlation_volumes.cat(corr) \
if self.correlation_volumes is not None else corr
### Gru hidden states are initialized to the context features, and then they evolve
gru_hidden_state = self.contexts_imgs[None, ii, 0] # Only for cam0
self.gru_hidden_states = torch.cat([self.gru_hidden_states, gru_hidden_state], 1) \
if self.gru_hidden_states is not None else gru_hidden_state
### Gru input states are initialized to the context features, and they do not evolve
gru_context_input = self.cst_contexts_imgs[None, ii, 0] # Only for cam0
self.gru_contexts_input = torch.cat([self.gru_contexts_input, gru_context_input], 1) \
if self.gru_contexts_input is not None else gru_context_input
### Gru estimated flow is initialized to the reprojected flow, and then it evolves
with torch.cuda.amp.autocast(enabled=False):
target, _, _ = self.reproject(ii, jj)
weight = torch.zeros_like(target) # Initialize weights to 0
# Note: self.gru_estimated_flow starts out as an empty (0-edge) tensor, so concatenating `target' here
# effectively initializes the flow of the newly added edges to the reprojected flow.
self.gru_estimated_flow = torch.cat([self.gru_estimated_flow, target], 1) ## Init gru flow with the one from reprojection!
self.gru_estimated_flow_weight = torch.cat([self.gru_estimated_flow_weight, weight], 1)
# Frees up memory as well,
# TODO it doesn't do any marginalization?
# TODO should we update kf_idx to re-use the empty slots?
@torch.cuda.amp.autocast(enabled=True)
def rm_factors(self, mask, store=False):
""" drop edges from factor graph """
# store estimated factors that are removed
if store:
self.ii_inactive = torch.cat([self.ii_inactive, self.ii[mask]], 0)
self.jj_inactive = torch.cat([self.jj_inactive, self.jj[mask]], 0)
self.gru_estimated_flow_inactive = torch.cat([self.gru_estimated_flow_inactive, self.gru_estimated_flow[:,mask]], 1)
self.gru_estimated_flow_weight_inactive = torch.cat([self.gru_estimated_flow_weight_inactive, self.gru_estimated_flow_weight[:,mask]], 1)
# Actually remove edges
self.ii = self.ii[~mask]
self.jj = self.jj[~mask]
self.age = self.age[~mask]
if self.corr_impl == "volume":
self.correlation_volumes = self.correlation_volumes[~mask]
if self.gru_hidden_states is not None:
self.gru_hidden_states = self.gru_hidden_states[:,~mask]
if self.gru_contexts_input is not None:
self.gru_contexts_input = self.gru_contexts_input[:,~mask]
self.gru_estimated_flow = self.gru_estimated_flow[:,~mask]
self.gru_estimated_flow_weight = self.gru_estimated_flow_weight[:,~mask]
# TODO: what about those keyframes that are isolated because no more factors touch them?
def __filter_repeated_edges(self, ii, jj):
""" remove duplicate edges """
keep = torch.zeros(ii.shape[0], dtype=torch.bool, device=ii.device)
eset = set(
[(i.item(), j.item()) for i, j in zip(self.ii, self.jj)] +
[(i.item(), j.item()) for i, j in zip(self.ii_inactive, self.jj_inactive)])
for k, (i, j) in enumerate(zip(ii, jj)):
keep[k] = (i.item(), j.item()) not in eset
return ii[keep], jj[keep]
def reproject(self, ii, jj, cam_T_body=None, jacobian=False):
""" project points from ii -> jj """
ii, jj = self.format_indicies(ii, jj, device=self.device) # TODO: is this really necessary?
Gs = lietorch.SE3(self.cam0_T_world[None]) # TODO: this is a bit expensive no?
# TODO: It would be great to visualize both
coords, valid_mask, (Ji, Jj, Jz) = \
pops.projective_transform(Gs, self.cam0_idepths[None], self.cam0_intrinsics[None], ii, jj, cam_T_body=cam_T_body, jacobian=jacobian)
return coords, valid_mask, (Ji, Jj, Jz)
def print_edges(self):
ii = self.ii.cpu().numpy()
jj = self.jj.cpu().numpy()
ix = np.argsort(ii)
ii = ii[ix]
jj = jj[ix]
w = torch.mean(self.gru_estimated_flow_weight, dim=[0,2,3,4]).cpu().numpy()
w = w[ix]
for e in zip(ii, jj, w):
print(e)
print()
@staticmethod
def format_indicies(ii, jj, device="cpu"):
""" to device, long, {-1} """
if not isinstance(ii, torch.Tensor):
ii = torch.as_tensor(ii)
if not isinstance(jj, torch.Tensor):
jj = torch.as_tensor(jj)
ii = ii.to(device=device, dtype=torch.long).reshape(-1)
jj = jj.to(device=device, dtype=torch.long).reshape(-1)
return ii, jj
@torch.cuda.amp.autocast(enabled=True)
def __context_encoder(self, images): # image must be normalized
""" Context features """
context_maps, gru_input_maps = self.context_net(images).split([128,128], dim=2)
return context_maps.tanh().squeeze(0), gru_input_maps.relu().squeeze(0)
@torch.cuda.amp.autocast(enabled=True)
def __feature_encoder(self, images): # image must be normalized
""" Features for correlation volume """
return self.feature_net(images).squeeze(0)
# images has shape [b, c, 3 or 4, h, w], where b is the batch dimension, c the number of cameras,
# and h, w the image height and width. Drops the alpha channel if present and normalizes the images,
# either with fixed statistics (droid_normalization) or with per-image mean/std.
def _normalize_imgs(self, images, droid_normalization=True):
img_normalized = images[:,:,:3, ...] / 255.0 # Drop alpha channel
if droid_normalization:
mean = torch.as_tensor([0.485, 0.456, 0.406], device=self.device)[:, None, None]
stdv = torch.as_tensor([0.229, 0.224, 0.225], device=self.device)[:, None, None]
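# These are the standard ImageNet RGB mean/std statistics.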
else:
mean = img_normalized.mean(dim=(3,4), keepdim=True)
stdv = img_normalized.std(dim=(3,4), keepdim=True)
img_normalized = img_normalized.sub_(mean).div_(stdv)
return img_normalized
@torch.cuda.amp.autocast(enabled=True)
@torch.no_grad()
def has_enough_motion(self, current_imgs_features):
# Determines whether there is enough motion by looking at cam0 only
ic(self.last_kf_idx)
last_img_features = self.features_imgs[self.last_kf_idx][0]
current_img_features = current_imgs_features[0]
last_img_context = self.contexts_imgs[self.last_kf_idx][0]
last_img_gru_input = self.cst_contexts_imgs[self.last_kf_idx][0]
# Viz hidden features
if self.viz:
self.viz_hidden_features(last_img_features, current_img_features)
# Index correlation volume
corr = CorrBlock(last_img_features[None, None],
current_img_features[None, None])(self.coords0[None, None]) # TODO why not send the corr block?
# Approximate flow magnitude using 1 update iteration
_, delta, weight = self.update_net(last_img_context[None,None], last_img_gru_input[None,None], corr)
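# delta approximates the flow between the last keyframe and the current frame (at 1/dsf resolution);
# its mean magnitude is presumably what gets compared against self.motion_filter_thresh [px]
# to decide whether the frame has moved enough to be considered.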
# # Viz weight
# if self.viz:
# self.viz_weight(weight)