audio_dataset.py
import os.path

import pandas as pd
import torch
import torchaudio
from torch.utils.data import Dataset

# Map the raw CSV column names to descriptive ones.
NEW_COLUMN_NAMES = {
    '---g-f_I2yQ': 'youtube_video_id',
    '1': 'start_seconds',
    'people marching': 'label',
    'test': 'split',
}


class AudioDataset(Dataset):
    """Dataset of audio clips and text labels indexed by a CSV file."""

    def __init__(self, csv_file, audio_dir, split='train'):
        if split not in ['train', 'val']:
            raise ValueError('Split must be either "train" or "val"')
        # The CSV marks its validation rows as 'test'.
        split = 'test' if split == 'val' else 'train'
        self.audio_dir = audio_dir
        self.df = pd.read_csv(csv_file)
        self.rename_columns()
        self.add_columns()
        self.df = self.df[self.df['split'] == split]
        # Append a trailing space, then repeat each label string 100 times.
        self.df['label'] += ' '
        self.df['label'] = self.df['label'] * 100
        self.remove_invalid_rows()

    @staticmethod
    def check_validity(file_path):
        # A row is valid only if its audio file exists on disk.
        return os.path.isfile(file_path)

    @staticmethod
    def transform_waveform(waveform, sampling_rate):
        # Resample to 16 kHz and downmix to mono by averaging the channels.
        new_frequency = 16000
        transform = torchaudio.transforms.Resample(sampling_rate, new_frequency)
        waveform = transform(waveform)
        mono_waveform = torch.mean(waveform, dim=0, keepdim=True)
        return mono_waveform, new_frequency

    def remove_invalid_rows(self):
        # Drop rows whose audio file is missing.
        self.df['is_valid'] = self.df['audio_path'].apply(AudioDataset.check_validity)
        self.df = self.df[self.df['is_valid']]
        self.df = self.df.drop(columns=['is_valid'])

    def rename_columns(self):
        self.df.rename(columns=NEW_COLUMN_NAMES, inplace=True)

    def add_columns(self):
        # Build the expected audio file path for each YouTube video ID.
        self.df['audio_path'] = self.df['youtube_video_id'].apply(
            lambda x: self.audio_dir + '/' + 'audio_' + x + '.wav')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label = self.df.iloc[idx]['label']
        audio_path = self.df.iloc[idx]['audio_path']
        waveform, sample_rate = torchaudio.load(audio_path, channels_first=True)
        waveform, sample_rate = AudioDataset.transform_waveform(waveform, sample_rate)
        return waveform, label

    def get_df(self):
        return self.df
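
A minimal usage sketch, not part of the file above: the CSV path 'data/audioset.csv' and audio directory 'data/audio' are hypothetical placeholders. Since clips can differ in length, the loader uses batch_size=1 so the default collate function works without padding.

from torch.utils.data import DataLoader

# Hypothetical paths; point these at the real CSV index and audio directory.
dataset = AudioDataset('data/audioset.csv', 'data/audio', split='train')
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for waveform, label in loader:
    # waveform: (1, 1, num_samples) mono audio at 16 kHz after batching;
    # label: a list with one (repeated) label string.
    print(waveform.shape, label[0][:40])
    break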