-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataset.py
101 lines (80 loc) · 3.94 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import numpy as np
import math
from scipy.spatial.transform import Rotation as R
import torch
from config import config
class Dataset():
    """Loader for a TUM-RGBD-style Kinect dataset directory.

    Expects the dataset at ``<parent-of-cwd>/Data/<dataset_name>`` containing
    ``depth.txt``, ``rgb.txt`` and ``groundtruth.txt``. After construction the
    instance exposes:

      - ``depth_timestamps`` / ``depth_images_path``: per-frame depth entries
      - ``rgb_timestamps`` / ``rgb_images_path``: per-frame color entries
      - ``trajectory_timestamps`` / ``trajectory``: ground-truth timestamps and
        the *inverse* of each 4x4 pose matrix (torch tensors)
    """

    def __init__(self, dataset_name):
        """Resolve the dataset folder and run validation + parsing.

        Args:
            dataset_name: Name of the dataset subfolder under ``../Data/``.
        """
        root_folder = os.path.dirname(os.path.abspath(os.getcwd()))
        self.kinect_dataset_path = os.path.join(root_folder, "Data", dataset_name)
        self.make_dataset(self.kinect_dataset_path)
        self.preprocess(self.kinect_dataset_path)

    def make_dataset(self, directory):
        """Locate the three index files and check depth/rgb line counts match.

        Sets ``depth_path``, ``rgb_path``, ``groundtruth_path``,
        ``number_frames`` and ``rgb_number_frames``.

        Raises:
            Exception: if an index file is missing or the depth and rgb files
                have a different number of lines.
        """
        self.depth_path = os.path.join(directory, "depth.txt")
        self.rgb_path = os.path.join(directory, "rgb.txt")
        self.groundtruth_path = os.path.join(directory, "groundtruth.txt")
        if not os.path.exists(self.depth_path):
            raise Exception("Depth.txt not found. Please make sure that you downloaded the dataset and is storing it in ../Data/")
        if not os.path.exists(self.rgb_path):
            raise Exception("RGB.txt not found. Please make sure that you downloaded the dataset and is storing it in ../Data/")
        if not os.path.exists(self.groundtruth_path):
            raise Exception("Groundtruth.txt not found. Please make sure that you downloaded the dataset and is storing it in ../Data/")
        # Context managers guarantee the handles are closed even when the
        # mismatch Exception below is raised (the original leaked the
        # groundtruth handle unconditionally and all handles on the raise).
        with open(self.depth_path, "r") as depth_file:
            self.number_frames = sum(1 for _ in depth_file)
        with open(self.rgb_path, "r") as rgb_file:
            self.rgb_number_frames = sum(1 for _ in rgb_file)
        if self.number_frames != self.rgb_number_frames:
            print(self.number_frames, self.rgb_number_frames)
            raise Exception("Depth and Color files frames do not match. Please make sure that they have the same number of lines")

    def preprocess(self, directory):
        """Parse the index files into timestamp/path lists and inverse poses.

        The first 3 lines of every file are skipped (presumably the TUM-format
        comment header -- TODO confirm for other datasets). Each groundtruth
        data line is expected as ``timestamp tx ty tz qx qy qz qw``; the stored
        pose is the inverse of the matrix built from it.

        Args:
            directory: Dataset folder (unused here; paths come from
                ``make_dataset``, kept for signature compatibility).

        Raises:
            Exception: if a data line has fewer than 2 whitespace-separated
                fields.
        """
        self.rgb_timestamps = []
        self.rgb_images_path = []
        self.depth_timestamps = []
        self.depth_images_path = []
        self.trajectory_timestamps = []
        self.trajectory = []
        with open(self.depth_path, "r") as depth_file, open(self.rgb_path, "r") as rgb_file:
            for line_index, (depth_frame, color_frame) in enumerate(zip(depth_file, rgb_file)):
                if line_index < 3:
                    continue  # skip comment header
                depth_split = depth_frame.split()
                color_split = color_frame.split()
                if len(depth_split) < 2:
                    raise Exception("File does not follow typical format")
                if len(color_split) < 2:
                    raise Exception("File does not follow typical format")
                # Depth entries: "<timestamp> <relative image path>"
                self.depth_timestamps.append(depth_split[0])
                self.depth_images_path.append(depth_split[1])
                # RGB entries: same layout as depth
                self.rgb_timestamps.append(color_split[0])
                self.rgb_images_path.append(color_split[1])
        with open(self.groundtruth_path, "r") as groundtruth_file:
            for line_index, trajectory_frame in enumerate(groundtruth_file):
                if line_index < 3:
                    continue  # skip comment header
                trajectory_split = trajectory_frame.split()
                if len(trajectory_split) < 2:
                    raise Exception("File does not follow typical format")
                self.trajectory_timestamps.append(trajectory_split[0])
                temp_traj = trajectory_split[1:]
                # Homogeneous 4x4 pose: translation in the last column,
                # rotation (from quaternion qx qy qz qw) in the top-left 3x3.
                trajMatrix = torch.eye(4)
                trajMatrix[0, 3] = float(temp_traj[0])
                trajMatrix[1, 3] = float(temp_traj[1])
                trajMatrix[2, 3] = float(temp_traj[2])
                rotMatrix = R.from_quat(temp_traj[3:]).as_matrix()
                trajMatrix[0:3, 0:3] = torch.from_numpy(rotMatrix)
                # Guard retained from the original: stop on a degenerate
                # (all-zero) rotation instead of inverting a singular matrix.
                if np.linalg.norm(rotMatrix) == 0:
                    break
                self.trajectory.append(torch.linalg.inv(trajMatrix))