From 4f405b3acec77151ba9757c5ef9790c7b454ad6e Mon Sep 17 00:00:00 2001
From: BenAAndrew
Date: Wed, 15 Sep 2021 15:52:27 +0100
Subject: [PATCH] Formatting & comment coverage

---
 dataset/audio_processing.py |  1 +
 dataset/clip_generator.py   | 88 ++++++++++++++++++++++++++++++++++---
 dataset/silero_utils.py     | 39 ++++++++--------
 dataset/transcribe.py       |  5 ++-
 4 files changed, 107 insertions(+), 26 deletions(-)

diff --git a/dataset/audio_processing.py b/dataset/audio_processing.py
index e429357..3b99401 100644
--- a/dataset/audio_processing.py
+++ b/dataset/audio_processing.py
@@ -119,6 +119,7 @@ def cut_audio(input_path, start, end, output_folder):
     """
 
     def _timestamp_to_filename(timestamp):
+        """Removes non-numeric characters from timestamp"""
        return re.sub("[^0-9]", "", timestamp)
 
     output_name = f"{_timestamp_to_filename(start)}_{_timestamp_to_filename(end)}.wav"
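
# Illustration (outside the patch): what the helper documented above produces.
# The regex is copied verbatim from cut_audio; the timestamps are examples.
import re

def _timestamp_to_filename(timestamp):
    """Removes non-numeric characters from timestamp"""
    return re.sub("[^0-9]", "", timestamp)

assert _timestamp_to_filename("00:01:23.450") == "000123450"
# So a clip cut between 00:01:23.450 and 00:01:25.100 is saved as
# "000123450_000125100.wav" by cut_audio's output_name pattern.
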
diff --git a/dataset/clip_generator.py b/dataset/clip_generator.py
index ef19dac..98dd398 100644
--- a/dataset/clip_generator.py
+++ b/dataset/clip_generator.py
@@ -14,7 +14,7 @@
 from dataset.forced_alignment.search import FuzzySearch
 from dataset.forced_alignment.audio import DEFAULT_RATE
 from dataset.audio_processing import change_sample_rate, cut_audio, add_silence
-from training import PUNCTUATION, BASE_SYMBOLS
+from training import PUNCTUATION
 
 
 MIN_LENGTH = 1.0
@@ -24,10 +24,32 @@
 def clip_combiner(audio_path, output_path, clips, max_length):
+    """
+    Combines clips to make them as long as possible without exceeding max length.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to audio file (must have been converted using convert_audio)
+    output_path : str
+        Path to save audio clips to
+    clips : list
+        List of current clips
+    max_length : float
+        Maximum duration of a clip in seconds
+
+    Returns
+    -------
+    (list, list)
+        List of clips and clip lengths in seconds
+    """
+
     def _get_duration(start, end):
+        """Gets the duration in seconds between two string timestamps"""
         return (datetime.strptime(end, "%H:%M:%S.%f") - datetime.strptime(start, "%H:%M:%S.%f")).total_seconds()
 
     def _join_text(lines):
+        """Joins list of lines with comma separation"""
         return " ".join(
             [
                 line + "," if not line[-1] in PUNCTUATION and i != len(lines) - 1 else line
@@ -36,6 +58,7 @@
     def _combine_clip(combined_clip, audio_path, output_path):
+        """Combines multiple clips to produce one new clip (or returns existing if list contains only one clip)"""
         if len(combined_clip) > 1:
             start = combined_clip[0]["start"]
             end = combined_clip[-1]["end"]
@@ -83,6 +106,33 @@ def generate_clips_from_textfile(
     max_length=MAX_LENGTH,
     min_confidence=MIN_CONFIDENCE,
 ):
+    """
+    Generates clips from plain text file.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to audio file (must have been converted using convert_audio)
+    script_path : str
+        Path to text file
+    transcription_model : TranscriptionModel
+        Transcription model
+    output_path : str
+        Path to save audio clips to
+    logging : logging (optional)
+        Logging object to write logs to
+    min_length : float (optional)
+        Minimum duration of a clip in seconds
+    max_length : float (optional)
+        Maximum duration of a clip in seconds
+    min_confidence : float (optional)
+        Minimum confidence score to generate a clip for
+
+    Returns
+    -------
+    (list, list)
+        List of clips and clip lengths in seconds
+    """
     logging.info(f"Loading script from {script_path}...")
     with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file:
         clean_text = script_file.read().lower().strip().replace("\n", " ").replace("  ", " ")
@@ -138,6 +188,33 @@ def generate_clips_from_subtitles(
     max_length=MAX_LENGTH,
     min_confidence=MIN_CONFIDENCE,
 ):
+    """
+    Generates clips from subtitles.
+
+    Parameters
+    ----------
+    audio_path : str
+        Path to audio file (must have been converted using convert_audio)
+    subtitle_path : str
+        Path to subtitle file
+    transcription_model : TranscriptionModel
+        Transcription model
+    output_path : str
+        Path to save audio clips to
+    logging : logging (optional)
+        Logging object to write logs to
+    min_length : float (optional)
+        Minimum duration of a clip in seconds
+    max_length : float (optional)
+        Maximum duration of a clip in seconds
+    min_confidence : float (optional)
+        Minimum confidence score to generate a clip for
+
+    Returns
+    -------
+    (list, list)
+        List of clips and clip lengths in seconds
+    """
     logging.info("Loading subtitles...")
     subs = pysrt.open(subtitle_path)
     total = len(subs)
@@ -196,6 +273,7 @@ def clip_generator(
 ):
     """
     Generates dataset clips & label file.
+    Also combines clips, adds silence, produces metadata/info & does cleanup.
 
     Parameters
     ----------
@@ -203,8 +281,8 @@ def clip_generator(
         Path to audio file (must have been converted using convert_audio)
     script_path : str
         Path to source text
-    transcription_model : DeepSpeech
-        DeepSpeech transcription model
+    transcription_model : TranscriptionModel
+        Transcription model
     forced_alignment_path : str
         Path to save alignment JSON to
     output_path : str
@@ -340,8 +418,8 @@ def extend_dataset(
         Path to audio file (must have been converted using convert_audio)
     script_path : str
         Path to source text
-    transcription_model : DeepSpeech
-        DeepSpeech transcription model
+    transcription_model : TranscriptionModel
+        Transcription model
     forced_alignment_path : str
         Path to save alignment JSON to
     output_path : str
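
# Illustration (outside the patch): a minimal sketch of calling the API whose
# docstrings are added above. File paths are placeholders, and the
# TranscriptionModel import and constructor are assumptions (the class is only
# referenced, never defined, in this diff).
import logging

from dataset.transcribe import TranscriptionModel  # assumed export
from dataset.clip_generator import generate_clips_from_textfile

model = TranscriptionModel()  # hypothetical construction
clips, clip_lengths = generate_clips_from_textfile(
    audio_path="converted/audio.wav",  # must already be converted using convert_audio
    script_path="script.txt",
    transcription_model=model,
    output_path="dataset/wavs",
    logging=logging,
    min_length=1.0,       # optional, defaults to MIN_LENGTH
    max_length=10.0,      # optional, defaults to MAX_LENGTH
    min_confidence=0.85,  # optional; clips scoring below this are skipped
)
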
diff --git a/dataset/silero_utils.py b/dataset/silero_utils.py
index 7682c3e..5fb2404 100644
--- a/dataset/silero_utils.py
+++ b/dataset/silero_utils.py
@@ -15,29 +15,27 @@
 from itertools import groupby
 
 
-class Decoder():
-    def __init__(self,
-                 labels: List[str]):
+class Decoder:
+    def __init__(self, labels: List[str]):
         self.labels = labels
-        self.blank_idx = self.labels.index('_')
-        self.space_idx = self.labels.index(' ')
+        self.blank_idx = self.labels.index("_")
+        self.space_idx = self.labels.index(" ")
 
-    def process(self,
-                probs, wav_len, word_align):
+    def process(self, probs, wav_len, word_align):
         assert len(self.labels) == probs.shape[1]
         for_string = []
         argm = torch.argmax(probs, axis=1)
         align_list = [[]]
         for j, i in enumerate(argm):
-            if i == self.labels.index('2'):
+            if i == self.labels.index("2"):
                 try:
                     prev = for_string[-1]
-                    for_string.append('$')
+                    for_string.append("$")
                     for_string.append(prev)
                     align_list[-1].append(j)
                     continue
                 except:
-                    for_string.append(' ')
+                    for_string.append(" ")
                     warnings.warn('Token "2" detected a the beginning of sentence, omitting')
                     align_list.append([])
                     continue
@@ -48,7 +46,7 @@ def process(self,
             else:
                 align_list[-1].append(j)
 
-        string = ''.join([x[0] for x in groupby(for_string)]).replace('$', '').strip()
+        string = "".join([x[0] for x in groupby(for_string)]).replace("$", "").strip()
 
         align_list = list(filter(lambda x: x, align_list))
 
@@ -64,25 +62,26 @@ def process(self,
                     to_move = min(1.5, len(argm) - i)
                     align_word[-1] = align_word[-1] + to_move
                 else:
-                    to_move = min(1.5, (align_list[i+1][0] - align_word[-1]) / 2)
+                    to_move = min(1.5, (align_list[i + 1][0] - align_word[-1]) / 2)
                     align_word[-1] = align_word[-1] + to_move
 
             for word, timing in zip(string.split(), align_list):
-                align_dicts.append({'word': word,
-                                    'start_ts': round(timing[0] * linear_align_coeff, 2),
-                                    'end_ts': round(timing[-1] * linear_align_coeff, 2)})
+                align_dicts.append(
+                    {
+                        "word": word,
+                        "start_ts": round(timing[0] * linear_align_coeff, 2),
+                        "end_ts": round(timing[-1] * linear_align_coeff, 2),
+                    }
+                )
             return string, align_dicts
 
         return string
 
-    def __call__(self,
-                 probs: torch.Tensor,
-                 wav_len: float = 0,
-                 word_align: bool = False):
+    def __call__(self, probs: torch.Tensor, wav_len: float = 0, word_align: bool = False):
         return self.process(probs, wav_len, word_align)
 
 
-def init_jit_model(model_path: str, device: torch.device = torch.device('cpu')):
+def init_jit_model(model_path: str, device: torch.device = torch.device("cpu")):
     torch.set_grad_enabled(False)
     model = torch.jit.load(model_path, map_location=device)
     model.eval()
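
# Illustration (outside the patch): how the reformatted Decoder is driven,
# following Silero's reference usage. The model path and waveform are
# placeholders; that the model emits per-frame label probabilities is an
# assumption consistent with the assert in Decoder.process.
import torch
from dataset.silero_utils import init_jit_model

model, decoder = init_jit_model("model.jit", device=torch.device("cpu"))
wav = torch.randn(1, 16000)  # placeholder: one second of 16 kHz audio
probs = model(wav)[0]        # per-frame probabilities for the single example
text, words = decoder(probs, wav_len=1.0, word_align=True)
for w in words:
    print(w["word"], w["start_ts"], w["end_ts"])
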
diff --git a/dataset/transcribe.py b/dataset/transcribe.py
index 7dc1bfe..bc963ed 100644
--- a/dataset/transcribe.py
+++ b/dataset/transcribe.py
@@ -89,7 +89,10 @@ def __init__(self, language="en"):
             self.model, self.decoder = init_jit_model(model, self.device)
         else:
             self.model, self.decoder, _ = torch.hub.load(
-                repo_or_dir="snakers4/silero-models", model="silero_stt", language=language, device=self.device, source="github"
+                repo_or_dir="snakers4/silero-models",
+                model="silero_stt",
+                language=language,
+                device=self.device,
             )
 
     def load_audio(self, path):
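
# Note (outside the patch): dropping source="github" in the hunk above is
# behaviour-preserving, since "github" is torch.hub.load's default source.
# A standalone sketch of the same fallback path; language and device are examples.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, decoder, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-models",
    model="silero_stt",
    language="en",
    device=device,
)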