Commit 62122de (1 parent: 556804d)
Showing 4 changed files with 450 additions and 0 deletions.
@@ -0,0 +1,106 @@
import glob
import os
import re

import numpy as np
import soundfile as sf
from torch.utils.data import Dataset

from audiozen.acoustics.io import subsample


class DNSAudio(Dataset):
    def __init__(self, root="./", limit=None, offset=0, sublen=6, train=True) -> None:
        """Audio dataset loader for DNS.

        Args:
            root: Path of the dataset location, by default "./".
            limit: Maximum number of noisy files to keep, by default no limit.
            offset: Number of noisy files to skip from the beginning, by default 0.
            sublen: Length of the training subsample in seconds, by default 6.
            train: Whether to randomly subsample each clip (training) or return it in full.
        """
        super().__init__()
        self.root = root
        print(f"Loading dataset from {root}...")
        self.noisy_files = glob.glob(root + "noisy/**.wav")

        if offset > 0:
            self.noisy_files = self.noisy_files[offset:]

        if limit:
            self.noisy_files = self.noisy_files[:limit]

        print(f"Found {len(self.noisy_files)} files.")

        # Parse metadata (file id, SNR, target level, source info) out of the filename.
        self.file_id_from_name = re.compile(r"fileid_(\d+)")
        self.snr_from_name = re.compile(r"snr(-?\d+)")
        self.target_level_from_name = re.compile(r"tl(-?\d+)")
        self.source_info_from_name = re.compile("^(.*?)_snr")
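        # Hypothetical example of a filename these patterns expect (not from this commit):
        #   "speech_snr15_tl-25_fileid_42.wav"
        #   -> source_info="speech", snr=15, target_level=-25, file_id=42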

        self.train = train
        self.sublen = sublen
        self.length = len(self.noisy_files)

    def __len__(self) -> int:
        """Length of the dataset."""
        return self.length

    def _get_filenames(self, n):
        noisy_file = self.noisy_files[n % self.length]
        filename = noisy_file.split(os.sep)[-1]
        file_id = int(self.file_id_from_name.findall(filename)[0])
        clean_file = self.root + f"clean/clean_fileid_{file_id}.wav"
        noise_file = self.root + f"noise/noise_fileid_{file_id}.wav"
        snr = int(self.snr_from_name.findall(filename)[0])
        target_level = int(self.target_level_from_name.findall(filename)[0])
        source_info = self.source_info_from_name.findall(filename)[0]
        metadata = {
            "snr": snr,
            "target_level": target_level,
            "source_info": source_info,
        }
        return noisy_file, clean_file, noise_file, metadata

    def __getitem__(self, n):
        """Gets the nth sample from the dataset.

        Args:
            n: Index of the sample to be retrieved.

        Returns:
            Noisy audio sample, clean audio sample, and the path of the noisy file.
        """
        noisy_file, clean_file, noise_file, metadata = self._get_filenames(n)
        noisy_audio, sampling_frequency = sf.read(noisy_file)
        clean_audio, _ = sf.read(clean_file)
        num_samples = 30 * sampling_frequency  # 30 sec data
        train_num_samples = self.sublen * sampling_frequency
        metadata["fs"] = sampling_frequency

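        # Pad with zeros or truncate so both signals are exactly 30 s long.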
        if len(noisy_audio) > num_samples:
            noisy_audio = noisy_audio[:num_samples]
        else:
            noisy_audio = np.concatenate([noisy_audio, np.zeros(num_samples - len(noisy_audio))])
        if len(clean_audio) > num_samples:
            clean_audio = clean_audio[:num_samples]
        else:
            clean_audio = np.concatenate([clean_audio, np.zeros(num_samples - len(clean_audio))])

        noisy_audio = noisy_audio.astype(np.float32)
        clean_audio = clean_audio.astype(np.float32)

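        # During training, crop a random `sublen`-second window and reuse the
        # same start index for the clean target so the pair stays time-aligned.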
        if self.train:
            noisy_audio, start_position = subsample(
                noisy_audio,
                subsample_length=train_num_samples,
                return_start_idx=True,
            )
            clean_audio = subsample(
                clean_audio,
                subsample_length=train_num_samples,
                start_idx=start_position,
            )

        return noisy_audio, clean_audio, noisy_file


if __name__ == "__main__":
    train_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/training_set/")
    validation_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/validation_set/")
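For context, a minimal sketch of how this dataset could be consumed with a PyTorch DataLoader. The root path, batch size, and variable names here are illustrative, not part of this commit:

from torch.utils.data import DataLoader

train_set = DNSAudio(root="./training_set/", sublen=6, train=True)
loader = DataLoader(train_set, batch_size=4, shuffle=True, drop_last=True)
# Default collation stacks the float32 arrays into (4, 6 * fs) tensors and
# gathers the noisy file paths into a list.
noisy, clean, noisy_paths = next(iter(loader))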
@@ -0,0 +1,86 @@
[meta]
save_dir = "exp"
description = "Train a model using Generative Adversarial Networks (GANs)"
seed = 20220815

[trainer]
path = "trainer.Trainer"
[trainer.args]
debug = false
max_steps = 0
max_epochs = 200
max_grad_norm = 10
save_max_score = true
save_ckpt_interval = 10
max_patience = 20
plot_norm = true
validation_interval = 10
max_num_checkpoints = 20
scheduler_name = "constant_schedule_with_warmup"
warmup_steps = 0
warmup_ratio = 0.00
gradient_accumulation_steps = 1

[loss_function]
path = "torch.nn.MSELoss"
[loss_function.args]

[optimizer]
path = "torch.optim.AdamW"
[optimizer.args]
lr = 1e-3

[model]
path = "audiozen.models.cirm_gsn.modeling_cirm_gsn.Model"
[model.args]
n_fft = 512
hop_length = 128
win_length = 512
fdrc = 0.5
input_size = 257
hidden_size = 268
num_layers = 4
proj_size = 257
output_activate_function = false
df_order = 3
use_pre_layer_norm_fb = true
bn = true
shared_weights = true
sequence_model = "GSN"
num_spks = 1

[acoustics]
n_fft = 512
hop_length = 128
win_length = 512
sr = 16000

[train_dataset]
path = "dataloader.DNSAudio"
[train_dataset.args]
root = "/datasets/datasets_fullband/training_set/"
limit = false
offset = 0
[train_dataset.dataloader]
batch_size = 64
num_workers = 8
drop_last = true
pin_memory = true

[validate_dataset]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set/"
train = false
[validate_dataset.dataloader]
batch_size = 16
num_workers = 8

[test_dataset]
path = "dataloader.DNSAudio"
[test_dataset.args]
root = "/nfs/xhao/data/intel_ndns/test_set/"
train = false
[test_dataset.dataloader]
batch_size = 1
num_workers = 0
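Every component above is declared as a `path` (a dotted import path) plus an `args` table, which the run script resolves via `audiozen.utils.instantiate`. As a rough sketch of what such a helper does, assuming a plain import-and-construct convention (this is an illustrative reimplementation, not the actual audiozen code):

import importlib

def instantiate_sketch(path, args=None, initialize=True):
    # Split "package.module.Name" into the module path and the attribute name.
    module_name, _, attr_name = path.rpartition(".")
    obj = getattr(importlib.import_module(module_name), attr_name)
    # With initialize=False, return the class itself (as done for the trainer below).
    return obj(**(args or {})) if initialize else obj

# e.g. instantiate_sketch("torch.optim.AdamW", {"params": model.parameters(), "lr": 1e-3})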
@@ -0,0 +1,151 @@
import argparse
from math import sqrt
from pathlib import Path

import toml
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from torch.utils.data import DataLoader

from audiozen.logger import init_logging_logger
from audiozen.utils import instantiate


def run(config, resume):
    init_logging_logger(config)

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=config["trainer"]["args"]["gradient_accumulation_steps"],
        kwargs_handlers=[ddp_kwargs],
    )

    set_seed(config["meta"]["seed"], device_specific=True)

    model = instantiate(config["model"]["path"], args=config["model"]["args"])

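    # Scale the base learning rate by sqrt(num_processes), a common heuristic for
    # keeping update magnitudes comparable as data-parallel workers multiply the
    # effective batch size.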
    optimizer = instantiate(
        config["optimizer"]["path"],
        args={"params": model.parameters()}
        | config["optimizer"]["args"]
        | {"lr": config["optimizer"]["args"]["lr"] * sqrt(accelerator.num_processes)},
    )

    loss_function = instantiate(
        config["loss_function"]["path"],
        args=config["loss_function"]["args"],
    )

    (model, optimizer) = accelerator.prepare(model, optimizer)

    # Note: `args` is read from module scope (parsed in the __main__ block below).
    if "train" in args.mode:
        train_dataset = instantiate(config["train_dataset"]["path"], args=config["train_dataset"]["args"])
        train_dataloader = DataLoader(
            dataset=train_dataset, collate_fn=None, shuffle=True, **config["train_dataset"]["dataloader"]
        )
        train_dataloader = accelerator.prepare(train_dataloader)

    if "train" in args.mode or "validate" in args.mode:
        if not isinstance(config["validate_dataset"], list):
            config["validate_dataset"] = [config["validate_dataset"]]

        validate_dataloaders = []
        for validate_config in config["validate_dataset"]:
            validate_dataset = instantiate(validate_config["path"], args=validate_config["args"])

            validate_dataloaders.append(
                accelerator.prepare(
                    DataLoader(
                        dataset=validate_dataset,
                        **validate_config["dataloader"],
                    )
                )
            )

    if "test" in args.mode:
        if not isinstance(config["test_dataset"], list):
            config["test_dataset"] = [config["test_dataset"]]

        test_dataloaders = []
        for test_config in config["test_dataset"]:
            test_dataset = instantiate(test_config["path"], args=test_config["args"])

            test_dataloaders.append(
                accelerator.prepare(
                    DataLoader(
                        dataset=test_dataset,
                        **test_config["dataloader"],
                    )
                )
            )

    trainer = instantiate(config["trainer"]["path"], initialize=False)(
        accelerator=accelerator,
        config=config,
        resume=resume,
        model=model,
        optimizer=optimizer,
        loss_function=loss_function,
    )

    for flag in args.mode:
        if flag == "train":
            trainer.train(train_dataloader, validate_dataloaders)
        elif flag == "validate":
            trainer.validate(validate_dataloaders)
        elif flag == "test":
            trainer.test(test_dataloaders, config["meta"]["ckpt_path"])
        elif flag == "predict":
            raise NotImplementedError("Predict is not implemented yet.")
        elif flag == "finetune":
            raise NotImplementedError("Finetune is not implemented yet.")
        else:
            raise ValueError(f"Unknown mode: {flag}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio-ZEN")
    parser.add_argument(
        "-C",
        "--configuration",
        required=True,
        type=str,
        help="Configuration (*.toml).",
    )
    parser.add_argument(
        "-M",
        "--mode",
        nargs="+",
        type=str,
        default=["train"],
        choices=["train", "validate", "test", "predict", "finetune"],
        help="Mode of the experiment.",
    )
    parser.add_argument(
        "-R",
        "--resume",
        action="store_true",
        help="Resume the experiment from the latest checkpoint.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=None,
        help="Checkpoint path for test. It can be 'best', 'latest', or a path to a checkpoint.",
    )

    args = parser.parse_args()

    config_path = Path(args.configuration).expanduser().absolute()
    config = toml.load(config_path.as_posix())

    config["meta"]["exp_id"] = config_path.stem
    config["meta"]["config_path"] = config_path.as_posix()

    if "test" in args.mode:
        if args.ckpt_path is None:
            raise ValueError("A checkpoint path is required for test. Use '--ckpt_path'.")
        else:
            config["meta"]["ckpt_path"] = args.ckpt_path

    run(config, args.resume)
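Given the flags defined above, a typical invocation would look like the following (the script and config filenames are illustrative):

python run.py -C baseline.toml -M train validate -R
python run.py -C baseline.toml -M test --ckpt_path best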