Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kws20.py with %66 augment chance #2

Merged
merged 2 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 57 additions & 40 deletions datasets/kws20.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def __parse_augmentation(self, augmentation):
print('No key `shift` in input augmentation dictionary! '
'Using defaults: [Min:-0.1, Max: 0.1]')
self.augmentation['shift'] = {'min': -0.1, 'max': 0.1}
if 'strech' not in augmentation:
print('No key `strech` in input augmentation dictionary! '
if 'stretch' not in augmentation:
print('No key `stretch` in input augmentation dictionary! '
'Using defaults: [Min: 0.8, Max: 1.3]')
self.augmentation['strech'] = {'min': 0.8, 'max': 1.3}
self.augmentation['stretch'] = {'min': 0.8, 'max': 1.3}

def __download(self):

Expand Down Expand Up @@ -201,11 +201,7 @@ def __gen_datasets(self, exp_len=16384):
record_list = sorted(os.listdir(os.path.join(self.raw_folder, label)))
record_len = len(record_list)

if not self.save_unquantized:
data_in = np.empty((record_len,
exp_len), dtype=np.uint8)
else:
data_in = np.empty((record_len,
data_in = np.empty((record_len,
exp_len), dtype=np.float32)


Expand Down Expand Up @@ -234,15 +230,7 @@ def __gen_datasets(self, exp_len=16384):
record = np.pad(record, [0, exp_len - record.size])

data_type[r , 0] = d_typ

if not self.save_unquantized:
data_in[r] = \
KWS.quantize_audio(record,
num_bits=self.quantization['bits'],
compand=self.quantization['compand'],
mu=self.quantization['mu'])
else:
data_in[r] = record
data_in[r] = record

dur = time.time() - time_s
print(f'Finished in {dur:.3f} seconds.')
Expand All @@ -269,27 +257,36 @@ def __gen_datasets(self, exp_len=16384):
def __dynamic_augment(self, record, fs = 16000, verbose=False, exp_len=16384, row_len=128, overlap_ratio=0):

audio = self.augment(record, fs)
audio = np.array(audio, np.uint8)
data_in = self.reshape_file(audio)

return data_in

def reshape_file(self, audio, row_len = 128, exp_len=16384, overlap_ratio=0):

overlap = int(np.ceil(row_len * overlap_ratio))
num_rows = int(np.ceil(exp_len / (row_len - overlap)))
data_len = int((num_rows * row_len - (num_rows - 1) * overlap))

if not self.save_unquantized:
data_in = np.empty((row_len, num_rows), dtype=np.uint8)
data_in = np.empty((row_len, num_rows), dtype=np.uint8)
else:
data_in = np.empty((row_len, num_rows), dtype=np.float32)

for n_r in range(num_rows):
start_idx = n_r * (row_len - overlap)
end_idx = start_idx + row_len
audio_chunk = audio[start_idx:end_idx]
audio_chunk = np.pad(audio_chunk, [0, row_len - audio_chunk.size])
audio_chunk = np.pad(audio_chunk, [0, row_len - audio_chunk.shape[0]])

data_in[:, n_r] = audio_chunk
if not self.save_unquantized:
data_in[:, n_r] = \
KWS.quantize_audio(audio_chunk,
num_bits=self.quantization['bits'],
compand=self.quantization['compand'],
mu=self.quantization['mu'])
else:
data_in[:, n_r] = audio_chunk

data_in = torch.from_numpy(data_in)

return data_in


Expand Down Expand Up @@ -463,7 +460,7 @@ def __filter_dtype(self):

self.data = self.data[idx_to_select, :]
self.targets = self.targets[idx_to_select, :]
del self.data_type
self.data_type = self.data_type[idx_to_select, :]


def __filter_classes(self):
Expand All @@ -487,9 +484,12 @@ def __len__(self):
return len(self.data)

def __getitem__(self, index):
inp, target = self.data[index], int(self.targets[index])
inp, target, data_type = self.data[index], int(self.targets[index]), self.data_type[index]

inp = self.__dynamic_augment(inp)
if data_type == 0:
inp = self.__dynamic_augment(inp)
else:
inp = self.reshape_file(inp)

inp = inp.type(torch.FloatTensor)

Expand All @@ -504,21 +504,22 @@ def __getitem__(self, index):
def add_white_noise(audio, noise_var_coeff):
"""Adds zero-mean Gaussian noise to audio, scaled by the given variance coefficient.
"""
coeff = noise_var_coeff * np.mean(np.abs(audio))
noisy_audio = audio + coeff * np.random.randn(len(audio))
audio_mean = torch.mean(torch.abs(audio))
coeff = noise_var_coeff * audio_mean
noisy_audio = audio + coeff * torch.randn(len(audio))
return noisy_audio

@staticmethod
def shift(audio, shift_sec, fs):
def shift(audio, shift_sec, fs=16000):
"""Shifts audio.
"""
shift_count = int(shift_sec * fs)
return np.roll(audio, shift_count)
return torch.roll(audio, shift_count)

@staticmethod
def stretch(audio, rate=1):
"""Stretches audio with specified ratio.
"""
"""
input_length = 16000
audio2 = librosa.effects.time_stretch(audio, rate)
if len(audio2) > input_length:
Expand All @@ -527,6 +528,11 @@ def stretch(audio, rate=1):
audio2 = np.pad(audio2, (0, max(0, input_length - len(audio2))), "constant")

return audio2

@staticmethod
def stretch_(audio, rate=1):
return torch.from_numpy(tsm.wsola(audio, rate))


def augment(self, audio, fs, verbose=False):
"""Augments audio by adding random noise, shift and stretch ratio.
Expand All @@ -535,17 +541,28 @@ def augment(self, audio, fs, verbose=False):
self.augmentation['noise_var']['max'])
random_shift_time = np.random.uniform(self.augmentation['shift']['min'],
self.augmentation['shift']['max'])
random_strech_coeff = np.random.uniform(self.augmentation['strech']['min'],
self.augmentation['strech']['max'])
random_stretch_coeff = np.random.uniform(self.augmentation['stretch']['min'],
self.augmentation['stretch']['max'])

augment_methods = {
"noise_var": [self.add_white_noise, random_noise_var_coeff],
"shift": [self.shift, random_shift_time],
"stretch": [self.stretch_, random_stretch_coeff]
}

for option in augment_methods:
# Each augmentation is applied independently with a 2-in-3 (~66%) chance
if np.random.randint(3) > 0:
aug_func = augment_methods[option][0]
audio = aug_func(audio, augment_methods[option][1])
else:
continue

aug_audio = tsm.wsola(audio, random_strech_coeff)
aug_audio = self.shift(aug_audio, random_shift_time, fs)
aug_audio = self.add_white_noise(aug_audio, random_noise_var_coeff)

if verbose:
print(f'random_noise_var_coeff: {random_noise_var_coeff:.2f}\nrandom_shift_time: \
{random_shift_time:.2f}\nrandom_strech_coeff: {random_strech_coeff:.2f}')
return aug_audio
{random_shift_time:.2f}\nrandom_stretch_coeff: {random_stretch_coeff:.2f}')
return audio

def augment_multiple(self, audio, fs, n_augment, verbose=False):
"""Calls `augment` function for n_augment times for given audio data.
Expand Down Expand Up @@ -628,7 +645,7 @@ def KWS_get_datasets(data, load_train=True, load_test=True, num_classes=6):
raise ValueError(f'Unsupported num_classes {num_classes}')

augmentation = {'aug_num': 2, 'shift': {'min': -0.15, 'max': 0.15},
'noise_var': {'min': 0, 'max': 1}}
'noise_var': {'min': 0, 'max': 0.05}}
quantization_scheme = {'compand': False, 'mu': 10}

if load_train:
Expand Down
73 changes: 73 additions & 0 deletions utils/log_parser
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3

"""
Training+validation log parser
"""

import re
import matplotlib.pyplot as plt

def log_parser(log_path, plot_loss = True, plot_acc = True):
    """Parse a training+validation log and optionally plot loss/accuracy curves.

    The whole file is read as text and scanned with four regular
    expressions: epoch indices (``Epoch: [N]``), training loss
    (``Overall Loss X``), validation loss (``Loss X``) and top-1 accuracy
    (``Top1 X``). The extracted series are truncated to a common length
    before plotting so that x and y always have the same number of points.

    Args:
        log_path: Path of the log file to read.
        plot_loss: When true, plot training and validation loss vs. epoch.
        plot_acc: When true, scatter-plot top-1 accuracy vs. epoch.

    Returns:
        Always 0.

    Raises:
        OSError: If the log file cannot be opened.
        ValueError: If a captured token is not a valid number (e.g. ``.``).
    """
    # Target values to locate
    epoch_pattern = re.compile(r'Epoch: \[(\d+)\]')
    loss_pattern = re.compile(r'Overall Loss (\d+\.\d+)')
    # NOTE(review): this pattern also matches the "Loss <number>" tail of an
    # "Overall Loss <number>" entry - confirm that is intended for the log
    # format being parsed.
    validation_loss_pattern = re.compile(r'Loss\s*([\d.]+)')
    top1_pattern = re.compile(r'Top1\s*([\d.]+)')

    # Open and read the log file
    with open(log_path, 'r') as log_file:
        log_contents = log_file.read()

    # Find corresponding values and convert them to numeric types
    epochs = [int(match) for match in epoch_pattern.findall(log_contents)]
    losses = [float(match) for match in loss_pattern.findall(log_contents)]
    validation_losses = [
        float(match) for match in validation_loss_pattern.findall(log_contents)
    ]
    top1_accuracies = [float(match) for match in top1_pattern.findall(log_contents)]

    # Align epochs, training losses and validation losses to one common
    # length; plotting series of different lengths raises in matplotlib.
    loss_length = min(len(epochs), len(losses), len(validation_losses))
    loss_epochs = epochs[:loss_length]
    training_losses = losses[:loss_length]
    validation_losses = validation_losses[:loss_length]

    # Make sure epochs and top-1 accuracies have the same length
    top1_length = min(len(loss_epochs), len(top1_accuracies))
    top1_epochs = loss_epochs[:top1_length]
    top1_accuracies = top1_accuracies[:top1_length]

    # Plot training + val loss vs epoch
    if plot_loss:
        plt.plot(loss_epochs, training_losses, label='Training Loss')
        plt.plot(loss_epochs, validation_losses, label='Validation Loss', color="r")
        plt.legend(loc="upper right")
        plt.xlabel('Epoch')
        plt.ylabel('Objective Loss')
        plt.title('Training Objective Loss Over Epochs')
        plt.grid(True)
        plt.show()

    # Plot top1 acc vs epoch
    if plot_acc:
        plt.figure(figsize=(10, 5))
        plt.scatter(top1_epochs, top1_accuracies, label='Top-1 Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Top-1 Accuracy (%)')
        plt.title('Top-1 Accuracy vs Epoch')
        plt.grid(True)
        plt.legend()
        plt.show()

    return 0