egs: add fullband gsn
haoxiangsnr committed Jan 23, 2024
1 parent 556804d commit 62122de
Showing 4 changed files with 450 additions and 0 deletions.
106 changes: 106 additions & 0 deletions recipes/intel_ndns/cirm_gsn/dataloader.py
@@ -0,0 +1,106 @@
import glob
import os
import re

import numpy as np
import soundfile as sf
from torch.utils.data import Dataset

from audiozen.acoustics.io import subsample


class DNSAudio(Dataset):
    def __init__(self, root="./", limit=None, offset=0, sublen=6, train=True) -> None:
        """Audio dataset loader for DNS.

        Args:
            root: Path of the dataset location, by default './'.
            limit: Maximum number of noisy files to load; None loads all of them.
            offset: Number of noisy files to skip from the beginning of the list.
            sublen: Length (in seconds) of the sub-segment drawn during training.
            train: If True, __getitem__ returns random sub-segments; otherwise full clips.
        """
        super().__init__()
        self.root = root
        print(f"Loading dataset from {root}...")
        self.noisy_files = glob.glob(root + "noisy/**.wav")

        if offset > 0:
            self.noisy_files = self.noisy_files[offset:]

        if limit:
            self.noisy_files = self.noisy_files[:limit]

        print(f"Found {len(self.noisy_files)} files.")

        # Noisy filenames are expected to follow the DNS convention,
        # e.g. "<source_info>_snr<SNR>_tl<TARGET_LEVEL>_fileid_<ID>.wav".
        self.file_id_from_name = re.compile(r"fileid_(\d+)")
        self.snr_from_name = re.compile(r"snr(-?\d+)")
        self.target_level_from_name = re.compile(r"tl(-?\d+)")
        self.source_info_from_name = re.compile("^(.*?)_snr")

        self.train = train
        self.sublen = sublen
        self.length = len(self.noisy_files)

    def __len__(self) -> int:
        """Length of the dataset."""
        return self.length

    def _get_filenames(self, n):
        noisy_file = self.noisy_files[n % self.length]
        filename = noisy_file.split(os.sep)[-1]
        file_id = int(self.file_id_from_name.findall(filename)[0])
        clean_file = self.root + f"clean/clean_fileid_{file_id}.wav"
        noise_file = self.root + f"noise/noise_fileid_{file_id}.wav"
        snr = int(self.snr_from_name.findall(filename)[0])
        target_level = int(self.target_level_from_name.findall(filename)[0])
        source_info = self.source_info_from_name.findall(filename)[0]
        metadata = {
            "snr": snr,
            "target_level": target_level,
            "source_info": source_info,
        }
        return noisy_file, clean_file, noise_file, metadata

    def __getitem__(self, n):
        """Gets the nth sample from the dataset.

        Args:
            n: Index of the sample to be retrieved.

        Returns:
            Noisy audio sample, clean audio sample, and the path of the noisy file.
        """
        noisy_file, clean_file, noise_file, metadata = self._get_filenames(n)
        noisy_audio, sampling_frequency = sf.read(noisy_file)
        clean_audio, _ = sf.read(clean_file)
        num_samples = 30 * sampling_frequency  # 30 sec data
        train_num_samples = self.sublen * sampling_frequency
        metadata["fs"] = sampling_frequency

        # Pad or trim both signals to exactly 30 s so every item has the same length.
        if len(noisy_audio) > num_samples:
            noisy_audio = noisy_audio[:num_samples]
        else:
            noisy_audio = np.concatenate([noisy_audio, np.zeros(num_samples - len(noisy_audio))])
        if len(clean_audio) > num_samples:
            clean_audio = clean_audio[:num_samples]
        else:
            clean_audio = np.concatenate([clean_audio, np.zeros(num_samples - len(clean_audio))])

        noisy_audio = noisy_audio.astype(np.float32)
        clean_audio = clean_audio.astype(np.float32)

        if self.train:
            # Draw a random sub-segment of `sublen` seconds; reuse the same start
            # index for the clean signal so the noisy/clean pair stays aligned.
            noisy_audio, start_position = subsample(
                noisy_audio,
                subsample_length=train_num_samples,
                return_start_idx=True,
            )
            clean_audio = subsample(
                clean_audio,
                subsample_length=train_num_samples,
                start_idx=start_position,
            )

        return noisy_audio, clean_audio, noisy_file


if __name__ == "__main__":
    train_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/training_set/")
    validation_set = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/validation_set/")
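
For a quick smoke test, the dataset drops straight into a standard PyTorch DataLoader. The following is a minimal sketch, assuming the DNS directory layout (noisy/, clean/) exists under the given root and the audio is 16 kHz as in default.toml; the root path is illustrative:

from torch.utils.data import DataLoader

dataset = DNSAudio(root="../../data/MicrosoftDNS_4_ICASSP/training_set/", sublen=6, train=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Default collation stacks the numpy arrays into tensors; the file paths stay a tuple of str.
noisy, clean, noisy_paths = next(iter(loader))
print(noisy.shape, clean.shape)  # expected: (4, 96000) each, i.e. 6 s at 16 kHz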
86 changes: 86 additions & 0 deletions recipes/intel_ndns/cirm_gsn/default.toml
@@ -0,0 +1,86 @@
[meta]
save_dir = "exp"
description = "Train a model using Generative Adversarial Networks (GANs)"
seed = 20220815

[trainer]
path = "trainer.Trainer"
[trainer.args]
debug = false
max_steps = 0
max_epochs = 200
max_grad_norm = 10
save_max_score = true
save_ckpt_interval = 10
max_patience = 20
plot_norm = true
validation_interval = 10
max_num_checkpoints = 20
scheduler_name = "constant_schedule_with_warmup"
warmup_steps = 0
warmup_ratio = 0.00
gradient_accumulation_steps = 1

[loss_function]
path = "torch.nn.MSELoss"
[loss_function.args]

[optimizer]
path = "torch.optim.AdamW"
[optimizer.args]
lr = 1e-3

[model]
path = "audiozen.models.cirm_gsn.modeling_cirm_gsn.Model"
[model.args]
n_fft = 512
hop_length = 128
win_length = 512
fdrc = 0.5
input_size = 257
hidden_size = 268
num_layers = 4
proj_size = 257
output_activate_function = false
df_order = 3
use_pre_layer_norm_fb = true
bn = true
shared_weights = true
sequence_model = "GSN"
num_spks = 1

[acoustics]
n_fft = 512
hop_length = 128
win_length = 512
sr = 16000

[train_dataset]
path = "dataloader.DNSAudio"
[train_dataset.args]
root = "/datasets/datasets_fullband/training_set/"
limit = false
offset = 0
[train_dataset.dataloader]
batch_size = 64
num_workers = 8
drop_last = true
pin_memory = true

[validate_dataset]
path = "dataloader.DNSAudio"
[validate_dataset.args]
root = "/datasets/datasets_fullband/validation_set/"
train = false
[validate_dataset.dataloader]
batch_size = 16
num_workers = 8

[test_dataset]
path = "dataloader.DNSAudio"
[test_dataset.args]
root = "/nfs/xhao/data/intel_ndns/test_set/"
train = false
[test_dataset.dataloader]
batch_size = 1
num_workers = 0
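
One consistency worth noting between [model.args] and [acoustics]: with n_fft = 512, the model's input_size and proj_size of 257 match the number of one-sided STFT frequency bins, n_fft / 2 + 1, and at sr = 16000 the 512-sample window and 128-sample hop correspond to 32 ms frames with an 8 ms shift. A quick check (a sketch, not part of the recipe):

n_fft, hop_length, sr = 512, 128, 16000
num_freq_bins = n_fft // 2 + 1   # one-sided spectrum -> 257
frame_ms = 1000 * n_fft / sr     # 32.0 ms analysis window
hop_ms = 1000 * hop_length / sr  # 8.0 ms frame shift
assert num_freq_bins == 257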
151 changes: 151 additions & 0 deletions recipes/intel_ndns/cirm_gsn/run.py
@@ -0,0 +1,151 @@
import argparse
from math import sqrt
from pathlib import Path

import toml
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from torch.utils.data import DataLoader

from audiozen.logger import init_logging_logger
from audiozen.utils import instantiate


def run(config, resume):
    init_logging_logger(config)

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=config["trainer"]["args"]["gradient_accumulation_steps"],
        kwargs_handlers=[ddp_kwargs],
    )

    set_seed(config["meta"]["seed"], device_specific=True)

    model = instantiate(config["model"]["path"], args=config["model"]["args"])

    # Square-root learning-rate scaling: the base lr from the config is multiplied
    # by sqrt(num_processes), e.g. 1e-3 becomes 2e-3 on 4 GPUs.
    # Note that the `|` dict unions require Python 3.9+.
    optimizer = instantiate(
        config["optimizer"]["path"],
        args={"params": model.parameters()}
        | config["optimizer"]["args"]
        | {"lr": config["optimizer"]["args"]["lr"] * sqrt(accelerator.num_processes)},
    )

    loss_function = instantiate(
        config["loss_function"]["path"],
        args=config["loss_function"]["args"],
    )

    (model, optimizer) = accelerator.prepare(model, optimizer)

if "train" in args.mode:
train_dataset = instantiate(config["train_dataset"]["path"], args=config["train_dataset"]["args"])
train_dataloader = DataLoader(
dataset=train_dataset, collate_fn=None, shuffle=True, **config["train_dataset"]["dataloader"]
)
train_dataloader = accelerator.prepare(train_dataloader)

if "train" in args.mode or "validate" in args.mode:
if not isinstance(config["validate_dataset"], list):
config["validate_dataset"] = [config["validate_dataset"]]

validate_dataloaders = []
for validate_config in config["validate_dataset"]:
validate_dataset = instantiate(validate_config["path"], args=validate_config["args"])

validate_dataloaders.append(
accelerator.prepare(
DataLoader(
dataset=validate_dataset,
**validate_config["dataloader"],
)
)
)

if "test" in args.mode:
if not isinstance(config["test_dataset"], list):
config["test_dataset"] = [config["test_dataset"]]

test_dataloaders = []
for test_config in config["test_dataset"]:
test_dataset = instantiate(test_config["path"], args=test_config["args"])

test_dataloaders.append(
accelerator.prepare(
DataLoader(
dataset=test_dataset,
**test_config["dataloader"],
)
)
)

trainer = instantiate(config["trainer"]["path"], initialize=False)(
accelerator=accelerator,
config=config,
resume=resume,
model=model,
optimizer=optimizer,
loss_function=loss_function,
)

for flag in args.mode:
if flag == "train":
trainer.train(train_dataloader, validate_dataloaders)
elif flag == "validate":
trainer.validate(validate_dataloaders)
elif flag == "test":
trainer.test(test_dataloaders, config["meta"]["ckpt_path"])
elif flag == "predict":
raise NotImplementedError("Predict is not implemented yet.")
elif flag == "finetune":
raise NotImplementedError("Finetune is not implemented yet.")
else:
raise ValueError(f"Unknown mode: {flag}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio-ZEN")
    parser.add_argument(
        "-C",
        "--configuration",
        required=True,
        type=str,
        help="Configuration (*.toml).",
    )
    parser.add_argument(
        "-M",
        "--mode",
        nargs="+",
        type=str,
        default=["train"],
        choices=["train", "validate", "test", "predict", "finetune"],
        help="Mode of the experiment.",
    )
    parser.add_argument(
        "-R",
        "--resume",
        action="store_true",
        help="Resume the experiment from the latest checkpoint.",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default=None,
        help="Checkpoint path for test. It can be 'best', 'latest', or a path to a checkpoint.",
    )

    args = parser.parse_args()

    config_path = Path(args.configuration).expanduser().absolute()
    config = toml.load(config_path.as_posix())

    config["meta"]["exp_id"] = config_path.stem
    config["meta"]["config_path"] = config_path.as_posix()

    if "test" in args.mode:
        if args.ckpt_path is None:
            raise ValueError("A checkpoint path is required for test. Use '--ckpt_path'.")
        else:
            config["meta"]["ckpt_path"] = args.ckpt_path

    run(config, args.resume)
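
For reference, typical invocations of this script would look like the following (assuming the working directory is recipes/intel_ndns/cirm_gsn and the dataset roots in default.toml point at real data):

python run.py -C default.toml -M train
python run.py -C default.toml -M train validate -R
python run.py -C default.toml -M test --ckpt_path best

-M accepts several modes at once, and test additionally requires --ckpt_path ('best', 'latest', or an explicit checkpoint path).

The commit does not include audiozen.utils.instantiate itself; judging from the call sites above, a plausible sketch of the pattern (an assumption, not audiozen's actual implementation) is:

import importlib

def instantiate_sketch(path, args=None, initialize=True):
    # Split "pkg.module.Attr" into a module path and an attribute name.
    module_path, _, attr_name = path.rpartition(".")
    cls = getattr(importlib.import_module(module_path), attr_name)
    # Return an instance built from keyword args, or the bare class when
    # initialize=False (run.py does this to construct the Trainer manually).
    return cls(**(args or {})) if initialize else cls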