Skip to content

Commit

Permalink
modified ctc segmentation
Browse files Browse the repository at this point in the history
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
  • Loading branch information
andrusenkoau committed May 3, 2024
1 parent 215a44d commit 51cd688
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@
"--max_duration",
type=int,
help="Maximum audio duration (seconds). Samples that are longer will be dropped",
default=60,
default=40,
)
parser.add_argument(
"--max_silence",
type=float,
help="Maximum silence duration to combain segments, s",
default=0.5,
default=1.5,
)

@dataclass
Expand Down Expand Up @@ -108,7 +108,7 @@ def process_alignment(alignment_file: str, manifest: str, clips_dir: str, args):

# set distribution
population = [1, 2]
weights = [0.85, 0.15]
weights = [0.80, 0.20]

with open(manifest, "a", encoding="utf8") as f:
new_segment = None
Expand All @@ -128,6 +128,7 @@ def process_alignment(alignment_file: str, manifest: str, clips_dir: str, args):
text_normalized=ref_text_normalized[i].strip(),)
else:
do_merge = choices(population, weights)[0] == 1
# do_merge = False
if st - new_segment.end_time < args.max_silence and end - new_segment.start_time < args.max_duration and do_merge:
new_segment.end_time = end
new_segment.text_processed += f" {ref_text_processed[i].strip()}"
Expand All @@ -146,8 +147,8 @@ def process_alignment(alignment_file: str, manifest: str, clips_dir: str, args):
"audio_filepath": audio_filepath,
"duration": duration,
"text": new_segment.text_processed,
"text_no_preprocessing": new_segment.text_no_preprocessing,
"text_normalized": new_segment.text_normalized,
"text_pc": new_segment.text_normalized,
"text_origin": new_segment.text_no_preprocessing,
"start_abs": float(np.mean(np.abs(segment_samples[:num_samples]))),
"end_abs": float(np.mean(np.abs(segment_samples[-num_samples:]))),
}
Expand Down
11 changes: 8 additions & 3 deletions tools/ctc_segmentation/scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import re
from glob import glob
from typing import List, Optional
import logging

import regex
from joblib import Parallel, delayed
Expand Down Expand Up @@ -242,11 +243,13 @@ def _split(sentences, delimiter):

vocabulary_symbols = []
for x in vocabulary:
if x != "<unk>":
if x != "<unk>" and x not in ".,?!":
# for BPE models
vocabulary_symbols.extend([x for x in x.replace("##", "").replace("▁", "")])
vocabulary_symbols = list(set(vocabulary_symbols))
vocabulary_symbols += [x.upper() for x in vocabulary_symbols]
# logging.warning(f"{vocabulary}")
# raise ValueError("Stop here")

# check to make sure there will be no utterances for segmentation with only OOV symbols
vocab_no_space_with_digits = set(vocabulary_symbols + [str(i) for i in range(10)])
Expand Down Expand Up @@ -355,8 +358,10 @@ def _split(sentences, delimiter):
model_name = args.model


vocabulary = asr_model.cfg.decoder.vocabulary
#vocabulary = asr_model.tokenizer.tokenizer.get_vocab()
if not args.model == "stt_en_fastconformer_hybrid_large_pc":
vocabulary = asr_model.cfg.decoder.vocabulary
else:
vocabulary = asr_model.tokenizer.tokenizer.get_vocab()

if os.path.isdir(args.in_text):
text_files = glob(f"{args.in_text}/*.txt")
Expand Down
37 changes: 29 additions & 8 deletions tools/ctc_segmentation/scripts/run_ctc_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,25 @@
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False)
else:
try:
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
#asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.model)
if not args.model == "stt_en_fastconformer_hybrid_large_pc":
if args.model == "stt_en_fastconformer_ctc_large":
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.model)
asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])
else:
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
else:
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.model)
asr_model.change_decoding_strategy(decoder_type="ctc")
asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])
except:
raise ValueError(
f"Provide path to the pretrained checkpoint or choose from {nemo_asr.models.EncDecCTCModel.get_available_model_names()}"
)

bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)
if not args.model == "stt_en_fastconformer_hybrid_large_pc":
bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)
else:
bpe_model = isinstance(asr_model, nemo_asr.models.EncDecHybridRNNTCTCBPEModel)

# get tokenizer used during training, None for char based models
if bpe_model:
Expand All @@ -93,8 +104,10 @@
tokenizer = None

# extract ASR vocabulary and add blank symbol
vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
#vocabulary = ["ε"] + list(asr_model.tokenizer.tokenizer.get_vocab())
if not args.model == "stt_en_fastconformer_hybrid_large_pc":
vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
else:
vocabulary = ["ε"] + list(asr_model.tokenizer.tokenizer.get_vocab())
logging.debug(f"ASR Model vocabulary: {vocabulary}")

data = Path(args.data)
Expand Down Expand Up @@ -138,9 +151,16 @@
logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}")
logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}")

log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[
0
].alignments
if not args.model == "stt_en_fastconformer_hybrid_large_pc":
log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[
0
].alignments
else:
log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[0][0].alignments

# logging.warning(f"************************************")
# logging.warning(f"log_probs: {log_probs.shape}")
# raise ValueError("Stop here")
# move blank values to the first column (ctc-package compatibility)
blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)
Expand Down Expand Up @@ -178,6 +198,7 @@
args.window_len,
log_file=log_file,
debug=args.debug,
use_pc_model=args.model == "stt_en_fastconformer_hybrid_large_pc",
)
for i in tqdm(range(len(all_log_probs)))
)
Expand Down
38 changes: 29 additions & 9 deletions tools/ctc_segmentation/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def get_segments(
window_size: int = 8000,
log_file: str = "log.log",
debug: bool = False,
use_pc_model: bool = True,
) -> None:
"""
Segments the audio into segments and saves segments timings to a file
Expand All @@ -62,24 +63,37 @@ def get_segments(
logging.basicConfig(handlers=handlers, level=level)

try:
with open(transcript_file, "r") as f:
text = f.readlines()
text = [t.strip() for t in text if t.strip()]

# with open(transcript_file, "r") as f:
# text = f.readlines()
# text = [t.strip() for t in text if t.strip()]


# add corresponding normalized original text
transcript_file_normalized = transcript_file.replace(".txt", ".pc.txt")
if not os.path.exists(transcript_file_normalized):
raise ValueError(f"{transcript_file_normalized} not found.")

# add corresponding original text without pre-processing
transcript_file_no_preprocessing = transcript_file.replace(".txt", "_with_punct.txt")
transcript_file_no_preprocessing = transcript_file.replace(".txt", ".origin.txt")
if not os.path.exists(transcript_file_no_preprocessing):
raise ValueError(f"{transcript_file_no_preprocessing} not found.")

# if we are using a PC model, we need to use the original transcript file
if not use_pc_model:
main_transcript_file = transcript_file
else:
# raise ValueError("PC model not supported.")
main_transcript_file = transcript_file_normalized
with open(main_transcript_file, "r") as f:
text = f.readlines()
text = [t.strip() for t in text if t.strip()]


with open(transcript_file_no_preprocessing, "r") as f:
text_no_preprocessing = f.readlines()
text_no_preprocessing = [t.strip() for t in text_no_preprocessing if t.strip()]

# add corresponding normalized original text
transcript_file_normalized = transcript_file.replace(".txt", "_with_punct_normalized.txt")
if not os.path.exists(transcript_file_normalized):
raise ValueError(f"{transcript_file_normalized} not found.")

with open(transcript_file_normalized, "r") as f:
text_normalized = f.readlines()
text_normalized = [t.strip() for t in text_normalized if t.strip()]
Expand Down Expand Up @@ -114,6 +128,12 @@ def get_segments(

timings, char_probs, char_list = cs.ctc_segmentation(config, log_probs, ground_truth_mat)
_print(ground_truth_mat, vocabulary)

if use_pc_model:
with open(transcript_file, "r") as f:
text = f.readlines()
text = [t.strip() for t in text if t.strip()]

segments = determine_utterance_segments(config, utt_begin_indices, char_probs, timings, text, char_list)

write_output(output_file, path_wav, segments, text, text_no_preprocessing, text_normalized)
Expand Down

0 comments on commit 51cd688

Please sign in to comment.