Skip to content

Commit

Permalink
Merge pull request #267 from blisc/u_logging_update
Browse files (browse the repository at this point in the history)
Logging update
  • Loading branch information
blisc authored Jan 21, 2020
2 parents 35209cb + 0200c9e commit bc0d7b1
Show file tree
Hide file tree
Showing 61 changed files with 578 additions and 682 deletions.
3 changes: 1 addition & 2 deletions collections/nemo_asr/nemo_asr/audio_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ def __init__(
frame_splicing=frame_splicing,
stft_conv=stft_conv,
pad_value=pad_value,
mag_power=mag_power,
logger=self._logger
mag_power=mag_power
)
self.featurizer.to(self._device)

Expand Down
9 changes: 4 additions & 5 deletions collections/nemo_asr/nemo_asr/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import partial
import torch

import nemo
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
from nemo.core.neural_types import *
Expand Down Expand Up @@ -149,14 +150,13 @@ def __init__(
'trim': trim_silence,
'bos_id': bos_id,
'eos_id': eos_id,
'logger': self._logger,
'load_audio': load_audio}

self._dataset = AudioDataset(**dataset_params)

# Set up data loader
if self._placement == DeviceType.AllGpu:
self._logger.info('Parallelizing DATALAYER')
nemo.logging.info('Parallelizing DATALAYER')
sampler = torch.utils.data.distributed.DistributedSampler(
self._dataset)
else:
Expand Down Expand Up @@ -275,13 +275,12 @@ def __init__(
'labels': labels,
'min_duration': min_duration,
'max_duration': max_duration,
'normalize': normalize_transcripts,
'logger': self._logger}
'normalize': normalize_transcripts}
self._dataset = KaldiFeatureDataset(**dataset_params)

# Set up data loader
if self._placement == DeviceType.AllGpu:
self._logger.info('Parallelizing DATALAYER')
nemo.logging.info('Parallelizing DATALAYER')
sampler = torch.utils.data.distributed.DistributedSampler(
self._dataset)
else:
Expand Down
44 changes: 13 additions & 31 deletions collections/nemo_asr/nemo_asr/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2019 NVIDIA Corporation

import torch
import nemo

from .metrics import word_error_rate

Expand Down Expand Up @@ -31,8 +32,7 @@ def __ctc_decoder_predictions_tensor(tensor, labels):
def monitor_asr_train_progress(tensors: list,
labels: list,
eval_metric='WER',
tb_logger=None,
logger=None):
tb_logger=None):
"""
Takes output of greedy ctc decoder and performs ctc decoding algorithm to
remove duplicates and special symbol. Prints sample to screen, computes
Expand All @@ -42,7 +42,6 @@ def monitor_asr_train_progress(tensors: list,
labels: A list of labels
eval_metric: An optional string from 'WER', 'CER'. Defaults to 'WER'.
tb_logger: Tensorboard logging object
logger:
Returns:
None
"""
Expand Down Expand Up @@ -72,16 +71,10 @@ def monitor_asr_train_progress(tensors: list,
wer = word_error_rate(hypotheses, references, use_cer=use_cer)
if tb_logger is not None:
tb_logger.add_scalar(tag, wer)
if logger:
logger.info(f'Loss: {tensors[0]}')
logger.info(f'{tag}: {wer*100 : 5.2f}%')
logger.info(f'Prediction: {hypotheses[0]}')
logger.info(f'Reference: {references[0]}')
else:
print(f'Loss: {tensors[0]}')
print(f'{tag}: {wer*100 : 5.2f}%')
print(f'Prediction: {hypotheses[0]}')
print(f'Reference: {references[0]}')
nemo.logging.info(f'Loss: {tensors[0]}')
nemo.logging.info(f'{tag}: {wer*100 : 5.2f}%')
nemo.logging.info(f'Prediction: {hypotheses[0]}')
nemo.logging.info(f'Reference: {references[0]}')


def __gather_losses(losses_list: list) -> list:
Expand Down Expand Up @@ -146,8 +139,7 @@ def process_evaluation_batch(tensors: dict, global_vars: dict, labels: list):

def process_evaluation_epoch(global_vars: dict,
eval_metric='WER',
tag=None,
logger=None):
tag=None):
"""
Calculates the aggregated loss and WER across the entire evaluation dataset
"""
Expand All @@ -165,24 +157,14 @@ def process_evaluation_epoch(global_vars: dict,
use_cer=use_cer)

if tag is None:
if logger:
logger.info(f"==========>>>>>>Evaluation Loss: {eloss}")
logger.info(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
else:
print(f"==========>>>>>>Evaluation Loss: {eloss}")
print(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
nemo.logging.info(f"==========>>>>>>Evaluation Loss: {eloss}")
nemo.logging.info(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
return {"Evaluation_Loss": eloss, f"Evaluation_{eval_metric}": wer}
else:
if logger:
logger.info(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
logger.info(f"==========>>>>>>Evaluation {eval_metric} {tag}: "
f"{wer*100 : 5.2f}%")
else:
print(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
print(f"==========>>>>>>Evaluation {eval_metric} {tag}:"
f" {wer*100 : 5.2f}%")
nemo.logging.info(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
nemo.logging.info(f"==========>>>>>>Evaluation {eval_metric} {tag}: "
f"{wer*100 : 5.2f}%")
return {f"Evaluation_Loss_{tag}": eloss,
f"Evaluation_{eval_metric}_{tag}": wer}

Expand Down
13 changes: 6 additions & 7 deletions collections/nemo_asr/nemo_asr/las/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pprint import pformat

import torch
import nemo
from nemo.backends.pytorch.common.metrics import char_lm_metrics

from nemo_asr.metrics import word_error_rate
Expand Down Expand Up @@ -55,7 +56,7 @@ def process_evaluation_batch(tensors, global_vars, labels, specials,

def process_evaluation_epoch(global_vars,
metrics=('loss', 'bpc', 'ppl'), calc_wer=False,
logger=None, mode='eval', tag='none'):
mode='eval', tag='none'):
tag = '_'.join(tag.lower().strip().split())
return_dict = {}
for metric in metrics:
Expand All @@ -70,17 +71,15 @@ def process_evaluation_epoch(global_vars,
transcript_texts = list(chain(*global_vars['transcript_texts']))
prediction_texts = list(chain(*global_vars['prediction_texts']))

if logger:
logger.info(f'Ten examples (transcripts and predictions)')
logger.info(transcript_texts[:10])
logger.info(prediction_texts[:10])
nemo.logging.info(f'Ten examples (transcripts and predictions)')
nemo.logging.info(transcript_texts[:10])
nemo.logging.info(prediction_texts[:10])

wer = word_error_rate(hypotheses=prediction_texts,
references=transcript_texts)
return_dict[f'metric/{mode}_wer_{tag}'] = wer

if logger:
logger.info(pformat(return_dict))
nemo.logging.info(pformat(return_dict))

return return_dict

Expand Down
30 changes: 14 additions & 16 deletions collections/nemo_asr/nemo_asr/parts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import torch
from torch.utils.data import Dataset

import nemo

from .manifest import ManifestBase, ManifestEN


Expand Down Expand Up @@ -131,7 +133,6 @@ def __init__(
trim=False,
bos_id=None,
eos_id=None,
logger=False,
load_audio=True,
manifest_class=ManifestEN):
m_paths = manifest_filepath.split(',')
Expand All @@ -141,19 +142,17 @@ def __init__(
max_utts=max_utts,
blank_index=blank_index,
unk_index=unk_index,
normalize=normalize,
logger=logger)
normalize=normalize)
self.featurizer = featurizer
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
if logger:
logger.info(
"Dataset loaded with {0:.2f} hours. Filtered {1:.2f} "
"hours.".format(
self.manifest.duration / 3600,
self.manifest.filtered_duration / 3600))
nemo.logging.info(
"Dataset loaded with {0:.2f} hours. Filtered {1:.2f} "
"hours.".format(
self.manifest.duration / 3600,
self.manifest.filtered_duration / 3600))

def __getitem__(self, index):
sample = self.manifest[index]
Expand Down Expand Up @@ -214,8 +213,7 @@ def __init__(
unk_index=-1,
blank_index=-1,
normalize=True,
eos_id=None,
logger=None):
eos_id=None):
self.eos_id = eos_id
self.unk_index = unk_index
self.blank_index = blank_index
Expand Down Expand Up @@ -245,8 +243,8 @@ def __init__(
f"KaldiFeatureDataset max_duration or min_duration is set but"
f" utt2dur file not found in {kaldi_dir}."
)
elif logger:
logger.info(
else:
nemo.logging.info(
f"Did not find utt2dur when loading data from "
f"{kaldi_dir}. Skipping dataset duration calculations."
)
Expand All @@ -265,7 +263,7 @@ def __init__(
text = line[split_idx:].strip()
if normalize:
text = ManifestEN.normalize_text(
text, labels, logger=logger)
text, labels)
dur = id2dur[utt_id] if id2dur else None

# Filter by duration if specified & utt2dur exists
Expand Down Expand Up @@ -295,9 +293,9 @@ def __init__(
print(f"Stop parsing due to max_utts ({max_utts})")
break

if logger and id2dur:
if id2dur:
# utt2dur durations are in seconds
logger.info(
nemo.logging.info(
f"Dataset loaded with {duration/60 : .2f} hours. "
f"Filtered {filtered_duration/60 : .2f} hours.")

Expand Down
13 changes: 4 additions & 9 deletions collections/nemo_asr/nemo_asr/parts/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .segment import AudioSegment
from torch_stft import STFT

import nemo

CONSTANT = 1e-5


Expand Down Expand Up @@ -127,7 +129,6 @@ def __init__(
stft_conv=False,
pad_value=0,
mag_power=2.,
logger=None
):
super(FilterbankFeatures, self).__init__()
if (n_window_size is None or n_window_stride is None
Expand All @@ -137,21 +138,15 @@ def __init__(
raise ValueError(
f"{self} got an invalid value for either n_window_size or "
f"n_window_stride. Both must be positive ints.")
if logger:
logger.info(f"PADDING: {pad_to}")
else:
print(f"PADDING: {pad_to}")
nemo.logging.info(f"PADDING: {pad_to}")

self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_conv = stft_conv

if stft_conv:
if logger:
logger.info("STFT using conv")
else:
print("STFT using conv")
nemo.logging.info("STFT using conv")

# Create helper class to patch forward func for use with AMP
class STFTPatch(STFT):
Expand Down
23 changes: 8 additions & 15 deletions collections/nemo_asr/nemo_asr/parts/manifest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Taken straight from Patter https://github.com/ryanleary/patter
# TODO: review, add copyright notice, and fix/add comments
import json
import nemo
import string

from nemo.utils import get_logger
from .cleaners import clean_text


Expand All @@ -17,8 +17,7 @@ def __init__(self,
max_utts=0,
blank_index=-1,
unk_index=-1,
normalize=True,
logger=None):
normalize=True):
self.min_duration = min_duration
self.max_duration = max_duration
self.sort_by_duration = sort_by_duration
Expand All @@ -27,9 +26,6 @@ def __init__(self,
self.unk_index = unk_index
self.normalize = normalize
self.labels_map = {label: i for i, label in enumerate(labels)}
self.logger = logger
if logger is None:
self.logger = get_logger('')

data = []
duration = 0.0
Expand All @@ -53,9 +49,9 @@ def __init__(self,
filtered_duration += item['duration']
continue
if normalize:
text = self.normalize_text(text, labels, logger=self.logger)
text = self.normalize_text(text, labels)
if not isinstance(text, str):
self.logger.warning(
nemo.logging.warning(
"WARNING: Got transcript: {}. It is not a "
"string. Dropping data point".format(text)
)
Expand All @@ -69,7 +65,7 @@ def __init__(self,

# support files using audio_filename
if 'audio_filename' in item and 'audio_filepath' not in item:
self.logger.warning(
nemo.logging.warning(
"Malformed manifest: The key audio_filepath was not "
"found in the manifest. Using audio_filename instead."
)
Expand All @@ -79,7 +75,7 @@ def __init__(self,
duration += item['duration']

if max_utts > 0 and len(data) >= max_utts:
self.logger.info(
nemo.logging.info(
'Stop parsing due to max_utts ({})'.format(max_utts))
break

Expand Down Expand Up @@ -155,7 +151,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def normalize_text(text, labels, logger=None):
def normalize_text(text, labels):
# Punctuation to remove
punctuation = string.punctuation
# Define punctuation that will be handled by text cleaner
Expand Down Expand Up @@ -183,10 +179,7 @@ def normalize_text(text, labels, logger=None):
try:
text = clean_text(text, table, punctuation_to_replace)
except BaseException:
if logger:
logger.warning("WARNING: Normalizing {} failed".format(text))
else:
print("WARNING: Normalizing {} failed".format(text))
nemo.logging.warning("WARNING: Normalizing {} failed".format(text))
return None

return text
Loading

0 comments on commit bc0d7b1

Please sign in to comment.