Skip to content
This repository has been archived by the owner on Sep 29, 2023. It is now read-only.

Commit

Permalink
Formatting & comment coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
BenAAndrew committed Sep 15, 2021
1 parent f7bb444 commit 4f405b3
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 26 deletions.
1 change: 1 addition & 0 deletions dataset/audio_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def cut_audio(input_path, start, end, output_folder):
"""

def _timestamp_to_filename(timestamp):
"""Removes non-numeric characters from timestamp"""
return re.sub("[^0-9]", "", timestamp)

output_name = f"{_timestamp_to_filename(start)}_{_timestamp_to_filename(end)}.wav"
Expand Down
88 changes: 83 additions & 5 deletions dataset/clip_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataset.forced_alignment.search import FuzzySearch
from dataset.forced_alignment.audio import DEFAULT_RATE
from dataset.audio_processing import change_sample_rate, cut_audio, add_silence
from training import PUNCTUATION, BASE_SYMBOLS
from training import PUNCTUATION


MIN_LENGTH = 1.0
Expand All @@ -24,10 +24,32 @@


def clip_combiner(audio_path, output_path, clips, max_length):
"""
Combines clips to make them as long as possible without exceeding max length.
Parameters
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
output_path : str
Path to save audio clips to
clips : list
List of current clips
max_length : float
Maximum duration of a clip in seconds
Returns
-------
(list, list)
List of clips and clip lengths in seconds
"""

def _get_duration(start, end):
"""Gets the duration in seconds between two string timestamps"""
return (datetime.strptime(end, "%H:%M:%S.%f") - datetime.strptime(start, "%H:%M:%S.%f")).total_seconds()

def _join_text(lines):
"""Joins list of lines with comma separation"""
return " ".join(
[
line + "," if not line[-1] in PUNCTUATION and i != len(lines) - 1 else line
Expand All @@ -36,6 +58,7 @@ def _join_text(lines):
)

def _combine_clip(combined_clip, audio_path, output_path):
"""Combines multiple clips to produce one new clip (or returns existing if list contains only one clip)"""
if len(combined_clip) > 1:
start = combined_clip[0]["start"]
end = combined_clip[-1]["end"]
Expand Down Expand Up @@ -83,6 +106,33 @@ def generate_clips_from_textfile(
max_length=MAX_LENGTH,
min_confidence=MIN_CONFIDENCE,
):
"""
Generates clips from plain text file.
Parameters
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
script_path : str
Path to text file
transcription_model : TranscriptionModel
Transcription model
output_path : str
Path to save audio clips to
logging : logging (optional)
Logging object to write logs to
min_length : float (optional)
Minimum duration of a clip in seconds
max_length : float (optional)
Maximum duration of a clip in seconds
min_confidence : float (optional)
Minimum confidence score to generate a clip for
Returns
-------
(list, list)
List of clips and clip lengths in seconds
"""
logging.info(f"Loading script from {script_path}...")
with open(script_path, "r", encoding=CHARACTER_ENCODING) as script_file:
clean_text = script_file.read().lower().strip().replace("\n", " ").replace(" ", " ")
Expand Down Expand Up @@ -138,6 +188,33 @@ def generate_clips_from_subtitles(
max_length=MAX_LENGTH,
min_confidence=MIN_CONFIDENCE,
):
"""
Generates clips from subtitles.
Parameters
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
subtitle_path : str
Path to subtitle file
transcription_model : TranscriptionModel
Transcription model
output_path : str
Path to save audio clips to
logging : logging (optional)
Logging object to write logs to
min_length : float (optional)
Minimum duration of a clip in seconds
max_length : float (optional)
Maximum duration of a clip in seconds
min_confidence : float (optional)
Minimum confidence score to generate a clip for
Returns
-------
(list, list)
List of clips and clip lengths in seconds
"""
logging.info("Loading subtitles...")
subs = pysrt.open(subtitle_path)
total = len(subs)
Expand Down Expand Up @@ -196,15 +273,16 @@ def clip_generator(
):
"""
Generates dataset clips & label file.
Also combines clips, adds silence, produces metadata/info & does cleanup.
Parameters
----------
audio_path : str
Path to audio file (must have been converted using convert_audio)
script_path : str
Path to source text
transcription_model : DeepSpeech
DeepSpeech transcription model
transcription_model : TranscriptionModel
Transcription model
forced_alignment_path : str
Path to save alignment JSON to
output_path : str
Expand Down Expand Up @@ -340,8 +418,8 @@ def extend_dataset(
Path to audio file (must have been converted using convert_audio)
script_path : str
Path to source text
transcription_model : DeepSpeech
DeepSpeech transcription model
transcription_model : TranscriptionModel
Transcription model
forced_alignment_path : str
Path to save alignment JSON to
output_path : str
Expand Down
39 changes: 19 additions & 20 deletions dataset/silero_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,27 @@
from itertools import groupby


class Decoder():
def __init__(self,
labels: List[str]):
class Decoder:
def __init__(self, labels: List[str]):
self.labels = labels
self.blank_idx = self.labels.index('_')
self.space_idx = self.labels.index(' ')
self.blank_idx = self.labels.index("_")
self.space_idx = self.labels.index(" ")

def process(self,
probs, wav_len, word_align):
def process(self, probs, wav_len, word_align):
assert len(self.labels) == probs.shape[1]
for_string = []
argm = torch.argmax(probs, axis=1)
align_list = [[]]
for j, i in enumerate(argm):
if i == self.labels.index('2'):
if i == self.labels.index("2"):
try:
prev = for_string[-1]
for_string.append('$')
for_string.append("$")
for_string.append(prev)
align_list[-1].append(j)
continue
except:
for_string.append(' ')
for_string.append(" ")
warnings.warn('Token "2" detected a the beginning of sentence, omitting')
align_list.append([])
continue
Expand All @@ -48,7 +46,7 @@ def process(self,
else:
align_list[-1].append(j)

string = ''.join([x[0] for x in groupby(for_string)]).replace('$', '').strip()
string = "".join([x[0] for x in groupby(for_string)]).replace("$", "").strip()

align_list = list(filter(lambda x: x, align_list))

Expand All @@ -64,25 +62,26 @@ def process(self,
to_move = min(1.5, len(argm) - i)
align_word[-1] = align_word[-1] + to_move
else:
to_move = min(1.5, (align_list[i+1][0] - align_word[-1]) / 2)
to_move = min(1.5, (align_list[i + 1][0] - align_word[-1]) / 2)
align_word[-1] = align_word[-1] + to_move

for word, timing in zip(string.split(), align_list):
align_dicts.append({'word': word,
'start_ts': round(timing[0] * linear_align_coeff, 2),
'end_ts': round(timing[-1] * linear_align_coeff, 2)})
align_dicts.append(
{
"word": word,
"start_ts": round(timing[0] * linear_align_coeff, 2),
"end_ts": round(timing[-1] * linear_align_coeff, 2),
}
)

return string, align_dicts
return string

def __call__(self,
probs: torch.Tensor,
wav_len: float = 0,
word_align: bool = False):
def __call__(self, probs: torch.Tensor, wav_len: float = 0, word_align: bool = False):
return self.process(probs, wav_len, word_align)


def init_jit_model(model_path: str, device: torch.device = torch.device('cpu')):
def init_jit_model(model_path: str, device: torch.device = torch.device("cpu")):
torch.set_grad_enabled(False)
model = torch.jit.load(model_path, map_location=device)
model.eval()
Expand Down
5 changes: 4 additions & 1 deletion dataset/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def __init__(self, language="en"):
self.model, self.decoder = init_jit_model(model, self.device)
else:
self.model, self.decoder, _ = torch.hub.load(
repo_or_dir="snakers4/silero-models", model="silero_stt", language=language, device=self.device, source="github"
repo_or_dir="snakers4/silero-models",
model="silero_stt",
language=language,
device=self.device,
)

def load_audio(self, path):
Expand Down

0 comments on commit 4f405b3

Please sign in to comment.