egs: update trainer.py
haoxiangsnr authored Jan 16, 2024
1 parent a15cece commit 4be515e
Showing 1 changed file with 36 additions and 71 deletions.
recipes/intel_ndns/spiking_fullsubnet/trainer.py
@@ -1,5 +1,4 @@
 import pandas as pd
-import torch
 from accelerate.logging import get_logger
 from tqdm import tqdm

@@ -18,8 +17,9 @@ def __init__(self, *args, **kwargs):
         self.stoi = STOI(sr=self.sr)
         self.pesq_wb = PESQ(sr=self.sr, mode="wb")
         self.pesq_nb = PESQ(sr=self.sr, mode="nb")
-        self.sisnr_loss = SISNRLoss(return_neg=False)
         self.si_sdr = SISDR()
+        self.sisnr_loss = SISNRLoss()
+        self.north_star_metric = "si_sdr"

     def training_step(self, batch, batch_idx):
         self.optimizer.zero_grad()
@@ -47,90 +47,55 @@ def training_step(self, batch, batch_idx):
             "loss_sdr_norm": loss_sdr_norm,
         }

-    def training_epoch_end(self, training_epoch_output):
-        # Compute mean loss on all loss items on epoch
-        for key in training_epoch_output[0].keys():
-            loss_items = [step_out[key] for step_out in training_epoch_output]
-            loss_mean = torch.mean(torch.tensor(loss_items))
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        mix_y, ref_y, id = batch
+        est_y, *_ = self.model(mix_y)

-            if self.accelerator.is_local_main_process:
-                logger.info(f"Loss '{key}' on epoch {self.state.epochs_trained}: {loss_mean}")
-                self.writer.add_scalar(f"Train_Epoch/{key}", loss_mean, self.state.epochs_trained)
+        if len(id) != 1:
+            raise ValueError(f"Expected batch size 1 during validation, got {len(id)}")

-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        noisy_y, clean_y, noisy_file = batch
-        enhanced_y, *_ = self.model(noisy_y)
-
-        # save enhanced audio
-        # stem = Path(noisy_file[0]).stem
-        # enhanced_dir = self.enhanced_dir / f"dataloader_{dataloader_idx}"
-        # enhanced_dir.mkdir(exist_ok=True, parents=True)
-        # enhanced_fpath = enhanced_dir / f"{stem}.wav"
-        # save_wav(enhanced_y, enhanced_fpath.as_posix(), self.sr)
-
-        # detach and move to cpu
-        # synops = compute_synops(
-        #     fb_out,
-        #     sb_out,
-        #     shared_weights=self.config["model_g"]["args"]["shared_weights"],
-        # )
-        # neuron_ops = compute_neuronops(fb_out, sb_out)
-
-        # to tensor
-        # synops = torch.tensor([synops], device=self.accelerator.device).unsqueeze(0)
-        # synops = synops.repeat(enhanced_y.shape[0], 1)
-        # neuron_ops = torch.tensor([neuron_ops], device=self.accelerator.device).unsqueeze(0)
-        # neuron_ops = neuron_ops.repeat(enhanced_y.shape[0], 1)
-
-        return noisy_y, clean_y, enhanced_y  # , synops, neuron_ops
-
-    def compute_metrics(self, dataloader_idx, step_out):
-        noisy, clean, enhanced = step_out
-
-        si_sdr = self.si_sdr(enhanced, clean)
-        dns_mos = self.dns_mos(enhanced)
-
-        return si_sdr | dns_mos
-
-    def compute_batch_metrics(self, dataloader_idx, step_out):
-        noisy, clean, enhanced = step_out
-        assert noisy.ndim == clean.ndim == enhanced.ndim == 2
-
-        # [num_ranks * batch_size, num_samples]
-        results = []
-        for i in range(noisy.shape[0]):
-            enhanced_i = enhanced[i, :]
-            clean_i = clean[i, :]
-            noisy_i = noisy[i, :]
-            results.append(
-                self.compute_metrics(
-                    dataloader_idx,
-                    (noisy_i, clean_i, enhanced_i),
-                )
-            )
+        # calculate metrics
+        mix_y = mix_y.squeeze(0).detach().cpu().numpy()
+        ref_y = ref_y.squeeze(0).detach().cpu().numpy()
+        est_y = est_y.squeeze(0).detach().cpu().numpy()

-        return results
+        si_sdr = self.si_sdr(est_y, ref_y)
+        dns_mos = self.dns_mos(est_y)
+
+        out = si_sdr | dns_mos
+        return [out]

-    def validation_epoch_end(self, outputs):
+    def validation_epoch_end(self, outputs, log_to_tensorboard=True):
         score = 0.0

         for dataloader_idx, dataloader_outputs in enumerate(outputs):
             logger.info(f"Computing metrics on epoch {self.state.epochs_trained} for dataloader {dataloader_idx}...")

-            rows = []
-            for step_out in tqdm(dataloader_outputs):
-                rows += self.compute_batch_metrics(dataloader_idx, step_out)
+            loss_dict_list = []
+            for step_loss_dict_list in tqdm(dataloader_outputs):
+                loss_dict_list.extend(step_loss_dict_list)

-            df_metrics = pd.DataFrame(rows)
+            df_metrics = pd.DataFrame(loss_dict_list)

             # Compute mean of all metrics
             df_metrics_mean = df_metrics.mean(numeric_only=True)
             df_metrics_mean_df = df_metrics_mean.to_frame().T

-            logger.info(f"\n{df_metrics_mean_df.to_markdown()}")
             time_now = self._get_time_now()
             df_metrics.to_csv(
                 self.metrics_dir / f"dl_{dataloader_idx}_epoch_{self.state.epochs_trained}_{time_now}.csv",
                 index=False,
             )
             df_metrics_mean_df.to_csv(
                 self.metrics_dir / f"dl_{dataloader_idx}_epoch_{self.state.epochs_trained}_{time_now}_mean.csv",
                 index=False,
             )

-            score += df_metrics_mean["OVRL"]
+            logger.info(f"\n{df_metrics_mean_df.to_markdown()}")
+            score += df_metrics_mean[self.north_star_metric]

-            for metric, value in df_metrics_mean.items():
-                self.writer.add_scalar(f"metrics_{dataloader_idx}/{metric}", value, self.state.epochs_trained)
+            if log_to_tensorboard:
+                for metric, value in df_metrics_mean.items():
+                    self.writer.add_scalar(f"metrics_{dataloader_idx}/{metric}", value, self.state.epochs_trained)

         return score
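Two details of the new `validation_step` are worth making explicit. The metric objects must return plain dicts keyed by metric name, because `si_sdr | dns_mos` is the Python 3.9+ dict-union operator rather than a tensor operation, and each step returns a single-element list so the epoch-end code can simply concatenate lists. A minimal sketch of that contract follows; `compute_si_sdr` and `compute_dns_mos` are hypothetical stand-ins for this repo's SISDR and DNSMOS wrappers, not the real implementations:

```python
import numpy as np


def compute_si_sdr(est: np.ndarray, ref: np.ndarray) -> dict:
    # Hypothetical stand-in: project the estimate onto the reference to get
    # the scaled target, then take target-to-residual energy ratio in dB.
    alpha = np.dot(est, ref) / (np.dot(ref, ref) + 1e-8)
    target = alpha * ref
    residual = est - target
    value = 10 * np.log10(np.sum(target**2) / (np.sum(residual**2) + 1e-8))
    return {"si_sdr": float(value)}


def compute_dns_mos(est: np.ndarray) -> dict:
    # Hypothetical stand-in: the real wrapper runs the pretrained DNSMOS
    # model; only the dict-of-scores return shape matters here.
    return {"SIG": 3.5, "BAK": 4.0, "OVRL": 3.2}


est = np.random.randn(16000)
ref = np.random.randn(16000)

# Dict union (Python 3.9+): keys from the right operand win on collision.
out = compute_si_sdr(est, ref) | compute_dns_mos(est)
# validation_step returns [out]: one merged metric dict per utterance.
```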

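On the epoch-end side, the per-step lists are flattened into one DataFrame per dataloader, column means are taken, and the column named by `self.north_star_metric` (now `si_sdr`, replacing the previously hard-coded DNSMOS `OVRL` score) is summed into the selection score. A self-contained sketch of that flow, with made-up numbers:

```python
import pandas as pd

north_star_metric = "si_sdr"  # was the hard-coded DNSMOS "OVRL" before

# outputs[dataloader_idx][step_idx] is the list a validation_step returned;
# the metric values below are invented for illustration.
outputs = [
    [
        [{"si_sdr": 14.2, "OVRL": 3.1}],
        [{"si_sdr": 15.0, "OVRL": 3.3}],
    ],
]

score = 0.0
for dataloader_outputs in outputs:
    # Flatten: list of per-step lists -> one list of per-utterance dicts.
    loss_dict_list = []
    for step_loss_dict_list in dataloader_outputs:
        loss_dict_list.extend(step_loss_dict_list)

    df_metrics = pd.DataFrame(loss_dict_list)
    df_metrics_mean = df_metrics.mean(numeric_only=True)

    # The north-star column drives checkpoint selection.
    score += df_metrics_mean[north_star_metric]

print(score)  # -> 14.6, the mean of the si_sdr column
```

The added `log_to_tensorboard` flag simply gates the scalar writes, so the same routine can be called outside the training loop without touching the writer.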