-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataloader.py
49 lines (38 loc) · 1.64 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torch.utils.data import Dataset
from boltons.fileutils import iter_find_files
import torchaudio as taudio
from scipy import signal
def collate_fn_padd(batch):
    """Collate a list of (spectrogram, length, filename) triples.

    Pads the variable-length spectrograms with zeros up to the longest
    one in the batch and stacks them along a new leading batch dim.

    Returns:
        (padded_spects, lengths, fnames) where padded_spects is
        (batch, max_time, freq), lengths is a LongTensor of the
        original time lengths, and fnames is a list of file paths.
    """
    spects, lengths, fnames = zip(*batch)
    # batch_first=True -> output shape is (batch, max_time, freq)
    padded_spects = torch.nn.utils.rnn.pad_sequence(spects, batch_first=True)
    return padded_spects, torch.LongTensor(lengths), list(fnames)
def preemphasis(x, coeff=0.97):
    """Apply a first-order pre-emphasis FIR filter: y[n] = x[n] - coeff * x[n-1].

    Filtering is done in float64 via scipy, then cast back to float32.
    """
    filtered = signal.lfilter([1, -coeff], [1], x)
    return torch.from_numpy(filtered).float()
def extract_features(wav_file, hp):
    """Load a wav file and compute its power spectrogram.

    Args:
        wav_file: path to a .wav file.
        hp: hyperparameter object; reads hp.emph, hp.n_fft, hp.normalize.

    Returns:
        Tensor of shape (time, freq) — first channel only, axes swapped
        from torchaudio's (channel, freq, time) layout.
    """
    wav, sr = taudio.load(wav_file)
    # Optional pre-emphasis before the STFT.
    if hp.emph > 0:
        wav = preemphasis(wav, coeff=hp.emph)
    to_spect = taudio.transforms.Spectrogram(
        n_fft=hp.n_fft,
        win_length=hp.n_fft,
        hop_length=sr // 100,  # 10 ms hop at the file's sample rate
        power=2,
        normalized=hp.normalize,
    )
    spect = to_spect(wav)
    # (channel, freq, time) -> (channel, time, freq), keep channel 0.
    return torch.transpose(spect, 1, 2)[0]
def get_test_dataset(hp):
    """Build the evaluation dataset described by hyperparameters *hp*."""
    dataset = WavDataset(hp)
    return dataset
class WavDataset(Dataset):
    """Dataset over every *.wav file found (recursively) under hp.test_dir.

    Each item is (spectrogram, length_in_frames, file_path).
    """

    def __init__(self, hp):
        # Keep the full hyperparameter object for feature extraction.
        self.hp = hp
        self.test_dir = hp.test_dir
        # Eagerly collect all wav paths so indexing is deterministic.
        self.wavs = list(iter_find_files(hp.test_dir, "*.wav"))

    def __getitem__(self, index):
        fname = self.wavs[index]
        spect = extract_features(fname, self.hp)
        # (spectrogram, time length, file path)
        return spect, spect.shape[0], fname

    def __len__(self):
        return len(self.wavs)