-
Notifications
You must be signed in to change notification settings - Fork 5
/
data_loader.py
51 lines (41 loc) · 1.85 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
from glob import glob
from torch.utils.data.sampler import SubsetRandomSampler
from transform import audio_transform
class UrbanSoundDataset(torch.utils.data.Dataset):
    """Dataset over UrbanSound8K ``.wav`` files spread across fold directories.

    Each item's label is looked up in *info_df* by matching the file's
    basename against the ``slice_file_name`` column and reading ``classID``.
    """

    def __init__(self, paths, info_df, transform=None):
        """
        Args:
            paths: iterable of directories scanned (non-recursively) for .wav files.
            info_df: pandas DataFrame with ``slice_file_name`` and ``classID`` columns.
            transform: optional callable applied to each loaded waveform.
        """
        self.transform = transform
        self.info_df = info_df
        self.file_list = []
        for path in paths:
            self.file_list.extend(glob(os.path.join(path, "*.wav")))
        # glob order is filesystem-dependent; sort so that indexing into the
        # dataset is reproducible across machines and runs.
        self.file_list.sort()

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filepath = self.file_list[idx]
        _, filename = os.path.split(filepath)
        # Select the metadata row explicitly instead of int(<Series>):
        # int() on a Series is deprecated in modern pandas and raises an
        # opaque TypeError when no row matches the filename.
        matches = self.info_df.loc[self.info_df["slice_file_name"] == filename, "classID"]
        if matches.empty:
            raise KeyError(f"no metadata row for {filename!r} in info_df")
        label = int(matches.iloc[0])
        waveform, sample_rate = torchaudio.load(filepath)
        if self.transform:
            waveform = self.transform(waveform)
        return waveform, label
def load_urbansound8k(data_path, batch_size, shuffle_dataset, random_seed=42):
    """Build train/test DataLoaders for UrbanSound8K.

    Folds 1-9 form the training set; fold 10 is held out for testing.

    Args:
        data_path: root directory containing ``fold1``..``fold10`` and
            ``UrbanSound8K.csv``.
        batch_size: batch size for both loaders.
        shuffle_dataset: whether training samples are drawn in random order.
        random_seed: seed for the shuffling RNG (used only when shuffling).

    Returns:
        ``(train_loader, test_loader)`` tuple.
    """
    train_paths = [os.path.join(data_path, f"fold{i}") for i in range(1, 10)]
    test_paths = [os.path.join(data_path, "fold10")]
    info_df = pd.read_csv(os.path.join(data_path, "UrbanSound8K.csv"))
    train_data = UrbanSoundDataset(train_paths, info_df, audio_transform)
    test_data = UrbanSoundDataset(test_paths, info_df, audio_transform)
    if shuffle_dataset:
        # BUG FIX: the original always used a SubsetRandomSampler, which
        # samples in random order regardless of shuffle_dataset, so the flag
        # (and the preliminary np.random.shuffle of the indices) had no
        # effect. Shuffle only when requested, with a seeded generator so
        # epoch ordering is reproducible.
        generator = torch.Generator()
        generator.manual_seed(random_seed)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=batch_size, shuffle=True, generator=generator
        )
    else:
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
    # Evaluation order stays deterministic: no shuffling for the test split.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
    return train_loader, test_loader