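"""Dataset utilities for the Greatest Hits dataset.

Defines `GreatestHitsDataset`, a PyTorch `Dataset` that pairs a single video
frame with a log-mel spectrogram of the surrounding audio and the material
label of the object being hit, plus a `create_dataloaders` helper that wraps
train/validation splits in `DataLoader`s.
"""
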
import glob
import time

import cv2
import matplotlib.pyplot as plt
import torch
import torchaudio
import whisper
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import (CenterCrop, Compose, Normalize,
                                    Resize, ToPILImage, ToTensor)


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    # CLIP-style preprocessing: resize so the short side is n_px, center-crop
    # to n_px x n_px, convert to RGB, and normalize.
    return Compose([
        ToPILImage(),
        Resize(n_px),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        # Normalization params from 1/6th of GreatestHits dataset
        Normalize((0.4738, 0.4298, 0.3702),
                  (0.1994, 0.1913, 0.1805)),
    ])
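
# A sketch of the expected input/output (the dummy frame is hypothetical,
# assuming numpy imported as np):
#
#   frame = np.zeros((360, 640, 3), dtype=np.uint8)  # HWC uint8, as from cv2
#   _transform(224)(frame).shape  # -> torch.Size([3, 224, 224]), normalized floats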


class GreatestHitsDataset(Dataset):
    """Pairs a video frame with surrounding audio and a material label.

    Expects `root_dir` to contain, per recording, `{date_time}_times.txt`,
    `{date_time}_denoised.mp4`, and `{date_time}_denoised.wav`.
    """

    def __init__(self, root_dir="./vis-data-256", audio_length=5):
        self.root_dir = root_dir
        self.transform = _transform(224)
        self.audio_rate = 96000
        self.audio_length = audio_length

        # Map each recording (keyed by its date-time prefix) to a list of
        # (timestamp, material) annotations, skipping unlabeled hits.
        self.times_info = {}
        for file in sorted(glob.glob(root_dir + "/*_times.txt")):
            with open(file, "r") as f:
                data = f.readlines()
            data = [line.strip().split() for line in data]
            data = [(float(line[0]), line[1]) for line in data if line[1] != "None"]
            self.times_info[file.split("/")[-1].split("_")[0]] = data

        # Build stable material-name <-> class-index mappings.
        self.all_material_names = set()
        for _, frame_info in self.times_info.items():
            self.all_material_names.update([frame[1] for frame in frame_info])
        self.all_material_names = sorted(self.all_material_names)
        self.mat_to_ind = {name: i for i, name in enumerate(self.all_material_names)}
        self.ind_to_mat = {i: name for i, name in enumerate(self.all_material_names)}

    def __len__(self):
        # Total number of annotated hits across all recordings.
        return sum(len(hits) for hits in self.times_info.values())

    def __getitem__(self, idx):
        # Walk the per-recording annotation lists to find the recording that
        # contains the idx-th hit, adjusting idx to be local to it.
        for key in self.times_info.keys():
            if idx < len(self.times_info[key]):
                break
            else:
                idx -= len(self.times_info[key])
        date_time = key

        # Load a single frame with cv2 for faster access than decoding the
        # whole video (e.g. via torchvision.io.read_video).
        cap = cv2.VideoCapture(self.root_dir + f"/{date_time}_denoised.mp4")
        audio = whisper.load_audio(self.root_dir + f"/{date_time}_denoised.wav", self.audio_rate)

        frames_info = self.times_info[date_time]
        frame_timestamp = frames_info[idx][0]
        material_name = frames_info[idx][1]
        material_index = torch.tensor(self.mat_to_ind[material_name])

        # Take an audio_length-second window centered on the frame timestamp;
        # clamp the start so an early hit cannot produce a negative index
        # (which would wrap around and yield an empty slice), then pad/trim
        # to a fixed length.
        audio_start_time = frame_timestamp - self.audio_length / 2
        audio_start_idx = max(0, int(audio_start_time * self.audio_rate))
        audio = audio[audio_start_idx : audio_start_idx + self.audio_rate * self.audio_length]
        audio = whisper.pad_or_trim(audio, self.audio_rate * self.audio_length)
        mel = whisper.log_mel_spectrogram(audio)

        cap.set(cv2.CAP_PROP_POS_MSEC, frame_timestamp * 1000)
        ret, frame = cap.read()
        cap.release()
        if not ret:
            raise RuntimeError(f"Could not read frame at {frame_timestamp}s from {date_time}_denoised.mp4")
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = self.transform(frame)

        # Add "audios_raw": audio here if the raw waveform is needed downstream.
        return {"images": frame, "audios": mel, "materials": material_index}


def create_dataloaders(root_dir="./vis-data-256", batch_size=4, val_ratio=0.05):
    """Split the dataset and wrap the parts in train/validation DataLoaders."""
    dataset = GreatestHitsDataset(root_dir)
    print(f"\nDataset size: {len(dataset)}\n")
    val_size = int(val_ratio * len(dataset))
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    return train_loader, val_loader


if __name__ == '__main__':
    # Smoke test: build the loaders, time a couple of batches, and dump the
    # first image, mel spectrogram, and (if present) raw audio of each batch.
    batch_size = 4

    start_time = time.time()
    train_loader, _ = create_dataloaders(root_dir="/home/GreatestHits/vis-data-256", batch_size=batch_size)
    print(f"Time taken to create dataloader (with batch size {batch_size}): {time.time() - start_time:.2f} seconds")

    start_time = time.time()
    for i, data in enumerate(train_loader):
        print([data[k].shape for k in data])
        print(f"\nTime taken to load batch {i+1}: {time.time() - start_time:.2f} seconds")
        plt.imsave(f"image_{i+1}.jpg", data["images"][0].permute(1, 2, 0).numpy().clip(0, 1))
        mel = data["audios"][0].numpy()
        plt.imsave(f"mel_{i+1}.jpg", mel, cmap="viridis")
        if "audios_raw" in data:
            torchaudio.save(f"audio_{i+1}.wav", data["audios_raw"][0].unsqueeze(0), 96000)
        if i == 1:
            break
        start_time = time.time()  # reset the timer for the next batch