-
Notifications
You must be signed in to change notification settings - Fork 1
/
Dataloaders.py
114 lines (99 loc) · 3.67 KB
/
Dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import cv2
import numpy
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import UCF101
from torchvision.transforms import transforms
from pathlib import Path
import os
import shutil
class customVideoDataset(Dataset):
    """Video-clip dataset read from a ``<path>/<label>/<video file>`` tree.

    Each sub-directory of ``path`` is one class; every file inside it is
    treated as a video of that class. Items are returned as
    ``(clip, label_index)`` where ``clip`` is a ``float16`` numpy array of
    shape ``(3, frame_count, height, width)`` holding the first
    ``frame_count`` frames (RGB, resized, raw 0-255 values) and
    ``label_index`` is the position of the class name in the sorted list of
    class names.

    Args:
        path: Root directory containing one sub-directory per class.
        frame_count: Number of frames kept per video.
        frame_size: ``(width, height)`` every frame is resized to, in the
            order expected by ``cv2.resize``. Defaults to ``(32, 32)``.
    """

    def __init__(self, path, frame_count, frame_size=(32, 32)):
        super().__init__()
        self.videos = []
        self.labels = []
        self.frames = frame_count
        self.frame_size = tuple(frame_size)
        folder = Path(path)
        for label in sorted(os.listdir(folder)):
            class_dir = os.path.join(folder, label)
            # Stray files in the root (README, .DS_Store, ...) are not classes.
            if not os.path.isdir(class_dir):
                continue
            # Sorted so that index -> video is reproducible across runs.
            for fname in sorted(os.listdir(class_dir)):
                self.videos.append(os.path.join(class_dir, fname))
                self.labels.append(label)
        self.label2index = {
            label: index for index, label in enumerate(sorted(set(self.labels)))
        }
        self.label_array = numpy.array(
            [self.label2index[label] for label in self.labels], dtype=int
        )

    def __getitem__(self, idx):
        """Return ``(clip, label_index)`` for the video at position ``idx``.

        Videos that cannot be opened, or that contain fewer than
        ``frame_count`` frames, are padded with all-zero frames.
        """
        width, height = self.frame_size
        # Zero-initialised (not numpy.empty) so missing frames are
        # deterministic padding instead of uninitialised memory.
        # Layout while filling: (frame, height, width, channel).
        stacked_frames = numpy.zeros(
            shape=(self.frames, height, width, 3), dtype=numpy.float16
        )
        video = cv2.VideoCapture(self.videos[idx])
        try:
            frame_count = 0
            while video.isOpened() and frame_count < self.frames:
                ret, frame = video.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (width, height))
                stacked_frames[frame_count] = frame
                frame_count += 1
        finally:
            # Release the capture even if decoding/resizing raised.
            video.release()
        # (frame, h, w, channel) -> (channel, frame, h, w)
        stacked_frames = stacked_frames.transpose((3, 0, 1, 2))
        return stacked_frames, self.label_array[idx]

    def __len__(self):
        """Return the number of videos found under the dataset root."""
        return len(self.videos)
def getDataloader(path, batch, workers, frames):
    """Create a shuffling ``DataLoader`` over the videos stored under ``path``.

    Args:
        path: Dataset root, forwarded to ``customVideoDataset``.
        batch: Batch size.
        workers: Number of loader worker processes.
        frames: Frames kept per video, forwarded as ``frame_count``.

    Returns:
        A ``DataLoader`` wrapping a ``customVideoDataset`` with shuffling on.
    """
    return DataLoader(
        customVideoDataset(path=path, frame_count=frames),
        batch_size=batch,
        num_workers=workers,
        shuffle=True,
    )
# Run this once to build the train/test split of UCF101.
def _move_listed_videos(list_path, root_dir, dest_root):
    """Move every video named in one annotation list to ``dest_root/<class>/``."""
    with open(list_path, 'r') as listing:
        for line in listing:
            fields = line.split()
            if not fields:
                continue  # blank line
            # Entries look like "<Class>/<file>" optionally followed by a
            # numeric label; splitting on whitespace first also drops the
            # trailing newline that would otherwise end up in the file name.
            parts = fields[0].split('/')
            if len(parts) < 2 or not parts[0] or not parts[1]:
                print('Skipping malformed entry: ' + line.strip())
                continue
            # Take the class from the entry itself rather than substring
            # matching on the file name ("Swing" is contained in "GolfSwing").
            class_name, video = parts[0], parts[1]
            dest_dir = os.path.join(dest_root, class_name)
            try:
                os.makedirs(dest_dir, exist_ok=True)
                shutil.move(src=os.path.join(root_dir, video), dst=dest_dir)
            except OSError as e:
                # Best effort: report (e.g. file already moved) and carry on.
                print(e)


def ufctraintest(root_dir, annotation_dir, target_dir):
    """Sort the videos in ``root_dir`` into ``target_dir/train`` and ``target_dir/test``.

    ``annotation_dir`` must contain ``classInd.txt`` (lines whose second
    whitespace-separated field is a class name) plus list files whose names
    contain ``train`` or ``test``. Every list line starts with
    ``<Class>/<video file>``; the video is looked up directly inside
    ``root_dir`` (flat layout) and moved to ``target_dir/<split>/<Class>/``.
    Missing destination folders are created. Move failures are printed and
    skipped. The current working directory is left untouched.

    NOTE(review): every matching list file is processed, train lists first.
    If the folder holds several splits that name the same video, the first
    list that mentions it wins - keep only the split you want in
    ``annotation_dir``.

    Args:
        root_dir: Directory that directly contains the video files.
        annotation_dir: Directory with ``classInd.txt`` and the list files.
        target_dir: Directory that receives the ``train``/``test`` trees.
    """
    train_root = os.path.join(target_dir, 'train')
    test_root = os.path.join(target_dir, 'test')

    train = []
    test = []
    for file in sorted(os.listdir(annotation_dir)):
        if not os.path.isfile(os.path.join(annotation_dir, file)):
            continue
        if 'train' in file:
            train.append(file)
        if 'test' in file:
            test.append(file)

    os.makedirs(train_root, exist_ok=True)
    os.makedirs(test_root, exist_ok=True)
    with open(os.path.join(annotation_dir, 'classInd.txt'), 'r') as classes:
        for line in classes:
            fields = line.split()
            if len(fields) < 2:
                continue  # blank or malformed line
            # Create each folder independently so an already existing train
            # folder no longer prevents the test folder from being created.
            os.makedirs(os.path.join(train_root, fields[1]), exist_ok=True)
            os.makedirs(os.path.join(test_root, fields[1]), exist_ok=True)

    classes = sorted(os.listdir(train_root))
    print(classes)
    print(len(classes))
    print('MOVING TRAINING FILES')
    for file in train:
        print(file)
        _move_listed_videos(os.path.join(annotation_dir, file), root_dir, train_root)

    classes = sorted(os.listdir(test_root))
    print(classes)
    print(len(classes))
    print('MOVING TEST FILES')
    for file in test:
        print(file)
        _move_listed_videos(os.path.join(annotation_dir, file), root_dir, test_root)