Skip to content

Commit

Permalink
Merge branch 'kws/signalmixer' into kws/benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
EyubogluMerve authored Apr 4, 2024
2 parents 76b784f + 9968e78 commit 991bdc0
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 538 deletions.
73 changes: 66 additions & 7 deletions datasets/kws20.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def __resample_convert_wav(self, folder_in, folder_out, sr=16000, ext='.flac'):

precursor_len = 30 * 128
postcursor_len = 98 * 128
utternace_threshold = 30
utterance_threshold = 30

while True:
if chunk_start + postcursor_len > len(data):
Expand All @@ -391,7 +391,7 @@ def __resample_convert_wav(self, folder_in, folder_out, sr=16000, ext='.flac'):
avg = 1000 * np.average(abs(chunk))
i += 128

if avg > utternace_threshold and chunk_start >= precursor_len:
if avg > utterance_threshold and chunk_start >= precursor_len:
print(f"\r Converting {converted_count + 1}/{total_count} "
f"to {frame_count + 1} segments", end=" ")
frame = data[chunk_start - precursor_len:chunk_start + postcursor_len]
Expand Down Expand Up @@ -640,7 +640,7 @@ def quantize_audio(data, num_bits=8, compand=False, mu=255):
q_data = np.clip(q_data, 0, max_val)
return np.uint8(q_data)

def energy_detector(self, audio, fs):
def get_audio_endpoints(self, audio, fs):
"""Future: May implement a method to detect the beginning & end of voice activity in audio.
Currently, it returns end points compatible with augmentation['shift'] values
"""
Expand All @@ -667,7 +667,7 @@ def speed_augment_multiple(self, audio, fs, exp_len, n_augment):
aug_audio = [None] * (n_augment + 1)
aug_speed = np.ones((n_augment + 1,))
shift_limits = np.zeros((n_augment + 1, 2))
voice_begin_idx, voice_end_idx = self.energy_detector(audio, fs)
voice_begin_idx, voice_end_idx = self.get_audio_endpoints(audio, fs)
aug_audio[0] = audio
for i in range(n_augment):
aug_audio[i+1], aug_speed[i+1] = self.speed_augment(audio, fs, sample_no=i)
Expand Down Expand Up @@ -874,7 +874,7 @@ def KWS_get_datasets(data, load_train=True, load_test=True, num_classes=6, bench
Data is augmented to 3x duplicate data by random stretch/shift and randomly adding noise where
the stretching coefficient, shift amount and noise SNR level are randomly selected between
0.8 and 1.3, -0.1 and 0.1, 5 and 30, respectively.
0.8 and 1.3, -0.1 and 0.1, -5 and 20, respectively.
"""
(data_dir, args) = data

Expand Down Expand Up @@ -1000,7 +1000,10 @@ def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True,
noise_type --> All noise types in the noise dataset.
"""

snr_range = range(snr_range[0], snr_range[1])
if len(snr_range) > 1:
snr_range = range(snr_range[0], snr_range[1])
else:
snr_range = list(snr_range)

(data_dir, _) = data

Expand All @@ -1027,7 +1030,6 @@ def KWS_20_msnoise_mixed_get_datasets(data, load_train=True, load_test=True,

return train_dataset, test_dataset


def KWS_12_benchmark_get_datasets(data, load_train=True, load_test=True):
"""
Returns the KWS dataset benchmark for 12 classes. 10 keywords and
Expand All @@ -1036,6 +1038,55 @@ def KWS_12_benchmark_get_datasets(data, load_train=True, load_test=True):
return KWS_get_datasets(data, load_train, load_test, num_classes=11, benchmark=True)


def MixedKWS_20_get_datasets_10dB(data, load_train=True, load_test=True,
                                  apply_prob=1, snr_range=tuple([10]),
                                  noise_type=MSnoise.class_dict.keys(),
                                  desired_probs=None):
    """
    Return the 20-keyword KWS train/test sets mixed with MSnoise through signalmixer.

    The clean sets come from `KWS_20_get_datasets`; each requested split is paired with
    an MSnoise dataset of the same length and wrapped in a `signalmixer`. With the
    defaults, the SNR list is `[10]` (10 dB) and `apply_prob` is 1.

    Args:
        data: `(data_dir, args)` pair; only `data_dir` is used here directly, the whole
            pair is forwarded to `KWS_20_get_datasets`.
        load_train: build the mixed training set when true, otherwise return None for it.
        load_test: build the mixed test set when true, otherwise return None for it.
        apply_prob: forwarded unchanged to `signalmixer`.
        snr_range: with more than one element, the first two are used as
            `range(start, stop)` bounds; with a single element it is used as a
            one-item list.
        noise_type: MSnoise classes to draw from (defaults to every MSnoise class).
        desired_probs: forwarded unchanged to `MSnoise`.

    Returns:
        `(train_dataset, test_dataset)`; an entry is None when its split was not requested.
    """
    snr_values = (range(snr_range[0], snr_range[1]) if len(snr_range) > 1
                  else list(snr_range))

    data_dir, _ = data

    clean_train, clean_test = KWS_20_get_datasets(data, load_train, load_test)

    def _add_noise(clean_dataset, split):
        """Wrap `clean_dataset` in a signalmixer fed by an equally long MSnoise set."""
        noise = MSnoise(root=data_dir, classes=noise_type,
                        d_type=split, dataset_len=len(clean_dataset),
                        desired_probs=desired_probs,
                        transform=None, quantize=False, download=False)
        return signalmixer(signal_dataset=clean_dataset,
                           snr_range=snr_values,
                           noise_type=noise_type, apply_prob=apply_prob,
                           noise_dataset=noise)

    # Train is built before test, matching the original construction order.
    train_dataset = _add_noise(clean_train, 'train') if load_train else None
    test_dataset = _add_noise(clean_test, 'test') if load_test else None

    return train_dataset, test_dataset


datasets = [
{
'name': 'KWS', # 6 keywords
Expand Down Expand Up @@ -1082,5 +1133,13 @@ def KWS_12_benchmark_get_datasets(data, load_train=True, load_test=True):
'silence', 'UNKNOWN'),
'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6, 0.06),
'loader': KWS_12_benchmark_get_datasets,
},
{
    # 20 keywords + UNKNOWN, loaded through MixedKWS_20_get_datasets_10dB.
    # The opening brace was missing, so this entry's key/value pairs sat
    # directly inside the `datasets` list, which is a syntax error.
    'name': 'MixedKWS20_10dB',
    'input': (128, 128),
    'output': ('up', 'down', 'left', 'right', 'stop', 'go', 'yes', 'no', 'on', 'off', 'one',
               'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'zero',
               'UNKNOWN'),
    # One weight per output class, in the same order as 'output'.
    'weight': (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.07),
    'loader': MixedKWS_20_get_datasets_10dB,
}
]
Loading

0 comments on commit 991bdc0

Please sign in to comment.