diff --git a/tools/nemo_forced_aligner/README.md b/tools/nemo_forced_aligner/README.md new file mode 100644 index 000000000000..1f96eba98887 --- /dev/null +++ b/tools/nemo_forced_aligner/README.md @@ -0,0 +1,84 @@ +# NeMo Forced Aligner (NFA) + +A tool for doing forced alignment using Viterbi decoding of NeMo CTC-based ASR models. + +## Usage example + +``` bash +python <path_to_NeMo>/tools/nemo_forced_aligner/align.py \ + pretrained_name="stt_en_citrinet_1024_gamma_0_25" \ + model_downsample_factor=8 \ + manifest_filepath=<path to manifest file> \ + output_dir=<path to output directory> +``` + +## How do I use NeMo Forced Aligner? +To use NFA, all you need to provide is a correct NeMo manifest (with `"audio_filepath"` and `"text"` fields). + +Call the `align.py` script, specifying the parameters as follows: + +* `pretrained_name`: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded from NGC and used for generating the log-probs which we will use to do alignment. Any QuartzNet, Citrinet or Conformer CTC model should work, in any language (only English has been tested so far). If `model_path` is specified, `pretrained_name` must not be specified. +>Note: NFA can only use CTC models (not Transducer models) at the moment. If you want to transcribe a long audio file (longer than ~5-10 mins), do not use a Conformer CTC model as that will likely give Out Of Memory errors. + +* `model_path`: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the log-probs which we will use to do alignment. If `pretrained_name` is specified, `model_path` must not be specified. +>Note: NFA can only use CTC models (not Transducer models) at the moment. If you want to transcribe a long audio file (longer than ~5-10 mins), do not use a Conformer CTC model as that will likely give Out Of Memory errors. + +* `model_downsample_factor`: the downsample factor of the ASR model. It should be 2 if your model is QuartzNet, 4 if it is Conformer CTC, 8 if it is Citrinet. + +* `manifest_filepath`: The path to the manifest of the data you want to align, containing `'audio_filepath'` and `'text'` fields. The audio filepaths need to be absolute paths. + +* `output_dir`: The folder where the CTM files containing the generated alignments and the new JSON manifest containing paths to those CTM files will be saved. There will be one CTM file per utterance (i.e. one CTM file per line in the manifest). The files will be called `<output_dir>/{tokens,words,additional_segments}/<utt_id>.ctm` and each line in each file will start with `<utt_id>`. By default, `utt_id` will be the stem of the audio_filepath. This can be changed by overriding `audio_filepath_parts_in_utt_id`. The new JSON manifest will be at `<output_dir>/{original manifest name}_with_ctm_paths.json`. + +* **[OPTIONAL]** `align_using_pred_text`: if True, will transcribe the audio using the ASR model (specified by `pretrained_name` or `model_path`) and then use that transcription as the 'ground truth' for the forced alignment. The `"pred_text"` will be saved in the output JSON manifest at `<output_dir>/{original manifest name}_with_ctm_paths.json`. To avoid overwriting other transcribed texts, if there are already `"pred_text"` entries in the original manifest, the program will exit without attempting to generate alignments. (Default: False). + +* **[OPTIONAL]** `transcribe_device`: The device that will be used for generating log-probs (i.e. transcribing). If None, NFA will set it to 'cuda' if it is available (otherwise will set it to 'cpu'). If specified, `transcribe_device` needs to be a string that can be input to the `torch.device()` method. (Default: `None`).
+ +* **[OPTIONAL]** `viterbi_device`: The device that will be used for doing Viterbi decoding. If None, NFA will set it to 'cuda' if it is available (otherwise will set it to 'cpu'). If specified, `viterbi_device` needs to be a string that can be input to the `torch.device()` method. (Default: `None`). + +* **[OPTIONAL]** `batch_size`: The batch size that will be used for generating log-probs and doing Viterbi decoding. (Default: 1). + +* **[OPTIONAL]** `additional_ctm_grouping_separator`: the string used to separate CTM segments if you want to obtain CTM files at a level that is not the token level or the word level. NFA will always produce token-level and word-level CTM files in: `<output_dir>/tokens/<utt_id>.ctm` and `<output_dir>/words/<utt_id>.ctm`. If `additional_ctm_grouping_separator` is specified, additional CTM files will be created at `<output_dir>/additional_segments/<utt_id>.ctm` containing CTMs for `additional_ctm_grouping_separator`-separated segments. (Default: `None`. Cannot be empty string or space (" "), as space-separated word-level CTMs will always be saved in `<output_dir>/words/<utt_id>.ctm`.) +> Note: the `additional_ctm_grouping_separator` will be removed from the ground truth text and all the output CTMs, i.e. it is treated as a marker which is not part of the ground truth. The separator will essentially be treated as a space, and any additional spaces around it will be amalgamated into one, i.e. if `additional_ctm_grouping_separator="|"`, the following texts will be treated equivalently: `"abc|def"`, `"abc |def"`, `"abc| def"`, `"abc | def"`. + +* **[OPTIONAL]** `remove_blank_tokens_from_ctm`: a boolean denoting whether to remove blank tokens from token-level output CTMs. (Default: False). + +* **[OPTIONAL]** `audio_filepath_parts_in_utt_id`: This specifies how many of the 'parts' of the audio_filepath we will use (starting from the final part of the audio_filepath) to determine the utt_id that will be used in the CTM files. (Default: 1, i.e. utt_id will be the stem of the basename of audio_filepath). Note also that any spaces that are present in the audio_filepath will be replaced with dashes, so as not to change the number of space-separated elements in the CTM files. + +* **[OPTIONAL]** `minimum_timestamp_duration`: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the middle outwards until it meets the `minimum_timestamp_duration`, or reaches the beginning or end of the audio file. Note that this may cause timestamps to overlap. (Default: 0, i.e. no modifications to the predicted durations). + +# Input manifest file format +By default, NFA needs to be provided with a 'manifest' file where each line specifies the absolute "audio_filepath" and "text" of each utterance that you wish to produce alignments for, like the format below: +```json +{"audio_filepath": "/absolute/path/to/audio.wav", "text": "the transcription of the utterance"} +``` + +You can omit the `"text"` field from the manifest if you specify `align_using_pred_text=true`. In that case, any `"text"` fields in the manifest will be ignored: the ASR model at `pretrained_name` or `model_path` will be used to transcribe the audio and obtain `"pred_text"`, which will be used as the 'ground truth' for the forced alignment process. The `"pred_text"` will also be saved in the output manifest JSON file at `<output_dir>/{original manifest name}_with_ctm_paths.json`. To remove the possibility of overwriting `"pred_text"`, NFA will raise an error if `align_using_pred_text=true` and there are existing `"pred_text"` fields in the original manifest.
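+
+For example, if `align_using_pred_text=true`, a manifest line as minimal as the following (the audio path here is just a placeholder) is enough:
+```json
+{"audio_filepath": "/absolute/path/to/audio.wav"}
+```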
+ +> Note: NFA does not require `"duration"` fields in the manifest, and can align long audio files without running out of memory. Depending on your machine specs, you can align audio files up to 5-10 minutes long with Conformer CTC models, up to around 1.5 hours long with QuartzNet models, and up to several hours long with Citrinet models. NFA will also produce better alignments the more accurate the ground-truth `"text"` is. + + +# Output CTM file format +For each utterance specified in a line of `manifest_filepath`, several CTM files will be generated: +* a CTM file containing token-level alignments at `<output_dir>/tokens/<utt_id>.ctm`, +* a CTM file containing word-level alignments at `<output_dir>/words/<utt_id>.ctm`, +* if `additional_ctm_grouping_separator` is specified, there will also be a CTM file containing those segments at `<output_dir>/additional_segments/<utt_id>.ctm`. +Each CTM file will contain lines of the format: +`<utt_id> 1 <start time in seconds> <duration in seconds> <text>`. +Note the second item in the line (the 'channel ID', which is required by the CTM file format) is always 1, as NFA operates on single-channel audio. + +# Output JSON manifest file format +A new manifest file will be saved at `<output_dir>/{original manifest name}_with_ctm_paths.json`. It will contain the same fields as the original manifest, and additionally: +* `"token_level_ctm_filepath"` +* `"word_level_ctm_filepath"` +* `"additional_segment_level_ctm_filepath"` (if `additional_ctm_grouping_separator` is specified) +* `"pred_text"` (if `align_using_pred_text=true`) + + +# How do I evaluate the alignment accuracy? +Ideally you would have some 'true' CTM files to compare with your generated CTM files. With these you could obtain metrics such as the mean (absolute) errors between the predicted starts/ends and the 'true' starts/ends of the segments. + +Alternatively (or additionally), you can visualize the quality of alignments using tools such as Gecko, which can play your audio file and display the predicted alignments at the same time. The Gecko tool requires you to upload an audio file and at least one CTM file. The Gecko tool can be accessed here: https://gong-io.github.io/gecko/. More information about the Gecko tool can be found on its GitHub page here: https://github.com/gong-io/gecko. + +**Note**: the following may help improve your experience viewing the CTMs in Gecko: +* setting `minimum_timestamp_duration` to a larger number, as Gecko may not display some tokens/words/segments properly if their timestamps are too short. +* setting `remove_blank_tokens_from_ctm=true` if you are analyzing token-level CTMs, as it will make the Gecko visualization less cluttered. diff --git a/tools/nemo_forced_aligner/align.py b/tools/nemo_forced_aligner/align.py new file mode 100644 index 000000000000..5f2a781a381f --- /dev/null +++ b/tools/nemo_forced_aligner/align.py @@ -0,0 +1,287 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import os +from dataclasses import dataclass, is_dataclass +from typing import Optional + +import torch +from omegaconf import OmegaConf +from utils.data_prep import ( + get_audio_sr, + get_batch_starts_ends, + get_batch_tensors_and_boundary_info, + get_manifest_lines_batch, + is_entry_in_all_lines, + is_entry_in_any_lines, +) +from utils.make_output_files import make_ctm, make_new_manifest +from utils.viterbi_decoding import viterbi_decoding + +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.utils.transcribe_utils import setup_model +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +""" +Align the utterances in manifest_filepath. +Results are saved in CTM files in output_dir. + +Arguments: + pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded + from NGC and used for generating the log-probs which we will use to do alignment. + Note: NFA can only use CTC models (not Transducer models) at the moment. + model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the + log-probs which we will use to do alignment. + Note: NFA can only use CTC models (not Transducer models) at the moment. + Note: only one of model_path and pretrained_name may be specified (the other must be None). + model_downsample_factor: an int indicating the downsample factor of the ASR model, i.e. the ratio of input + timesteps to output timesteps. + If the ASR model is a QuartzNet model, its downsample factor is 2. + If the ASR model is a Conformer CTC model, its downsample factor is 4. + If the ASR model is a Citrinet model, its downsample factor is 8. + manifest_filepath: filepath to the manifest of the data you want to align, + containing 'audio_filepath' and 'text' fields. + output_dir: the folder where output CTM files and the new JSON manifest will be saved. + align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription + as the 'ground truth' for the forced alignment. + transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing"). + The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available + (otherwise will set it to 'cpu'). + viterbi_device: None, or a string specifying the device that will be used for doing Viterbi decoding. + The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available + (otherwise will set it to 'cpu'). + batch_size: int specifying the batch size that will be used for generating log-probs and doing Viterbi decoding. + additional_ctm_grouping_separator: the string used to separate CTM segments if you want to obtain CTM files at a + level that is not the token level or the word level. NFA will always produce token-level and word-level CTM + files in: `<output_dir>/tokens/<utt_id>.ctm` and `<output_dir>/words/<utt_id>.ctm`. + If `additional_ctm_grouping_separator` is specified, additional CTM files will be created at + `<output_dir>/additional_segments/<utt_id>.ctm` containing CTMs + for `additional_ctm_grouping_separator`-separated segments. + remove_blank_tokens_from_ctm: a boolean denoting whether to remove blank tokens from token-level output CTMs.
+ audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath + we will use (starting from the final part of the audio_filepath) to determine the + utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath + will be replaced with dashes, so as not to change the number of space-separated elements in the + CTM files. + e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e-1" + e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e-1" + e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e-1" + minimum_timestamp_duration: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any + line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the + middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio + file. Note that this may cause timestamps to overlap. +""" + + +@dataclass +class AlignmentConfig: + # Required configs + pretrained_name: Optional[str] = None + model_path: Optional[str] = None + model_downsample_factor: Optional[int] = None + manifest_filepath: Optional[str] = None + output_dir: Optional[str] = None + + # General configs + align_using_pred_text: bool = False + transcribe_device: Optional[str] = None + viterbi_device: Optional[str] = None + batch_size: int = 1 + additional_ctm_grouping_separator: Optional[str] = None + remove_blank_tokens_from_ctm: bool = False + minimum_timestamp_duration: float = 0 + audio_filepath_parts_in_utt_id: int = 1 + + +@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig) +def main(cfg: AlignmentConfig): + + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + # Validate config + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None") + + if cfg.model_path is not None and cfg.pretrained_name is not None: + raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None") + + if cfg.model_downsample_factor is None: + raise ValueError("cfg.model_downsample_factor must be specified") + + if cfg.manifest_filepath is None: + raise ValueError("cfg.manifest_filepath must be specified") + + if cfg.output_dir is None: + raise ValueError("cfg.output_dir must be specified") + + if cfg.batch_size < 1: + raise ValueError("cfg.batch_size cannot be zero or a negative number") + + if cfg.additional_ctm_grouping_separator == "" or cfg.additional_ctm_grouping_separator == " ": + raise ValueError("cfg.additional_ctm_grouping_separator cannot be empty string or space character") + + if cfg.minimum_timestamp_duration < 0: + raise ValueError("cfg.minimum_timestamp_duration cannot be a negative number") + + # Validate manifest contents + if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. " + "All lines must contain an 'audio_filepath' entry." + ) + + if cfg.align_using_pred_text: + if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"): + raise RuntimeError( + "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath "
"contains 'pred_text' entries. This is because the audio will be transcribed and may produce " + "a different 'pred_text'. This may cause confusion." + ) + else: + if not is_entry_in_all_lines(cfg.manifest_filepath, "text"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain a 'text' entry. " + "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False." + ) + + # init devices + if cfg.transcribe_device is None: + transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + transcribe_device = torch.device(cfg.transcribe_device) + logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}") + + if cfg.viterbi_device is None: + viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + viterbi_device = torch.device(cfg.viterbi_device) + logging.info(f"Device to be used for Viterbi step (`viterbi_device`) is {viterbi_device}") + + if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda': + logging.warning( + 'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors ' + 'it may help to change both devices to be the CPU.' + ) + + # load model + model, _ = setup_model(cfg, transcribe_device) + + if not isinstance(model, EncDecCTCModel): + raise NotImplementedError( + f"Model {cfg.pretrained_name or cfg.model_path} is not an instance of NeMo EncDecCTCModel." + " Currently only instances of EncDecCTCModels are supported" + ) + + audio_sr = get_audio_sr(cfg.manifest_filepath) + logging.info( + f"Detected audio sampling rate {audio_sr}Hz in first audio in manifest at {cfg.manifest_filepath}. " + "Will assume all audios in manifest have this sampling rate. Sampling rate will be used to determine " + "timestamps in output CTM." + ) + + if cfg.minimum_timestamp_duration > 0: + logging.warning( + f"cfg.minimum_timestamp_duration has been set to {cfg.minimum_timestamp_duration} seconds. " + "This may cause the alignments for some tokens/words/additional segments to be overlapping."
+ ) + + # get start and end line IDs of batches + starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size) + + if cfg.align_using_pred_text: + # record pred_texts to save them in the new manifest at the end of this script + pred_text_all_lines = [] + else: + pred_text_all_lines = None + + # get alignment and save in CTM batch-by-batch + for start, end in zip(starts, ends): + manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end) + + ( + log_probs_batch, + y_batch, + T_batch, + U_batch, + token_info_batch, + word_info_batch, + segment_info_batch, + pred_text_batch, + ) = get_batch_tensors_and_boundary_info( + manifest_lines_batch, model, cfg.additional_ctm_grouping_separator, cfg.align_using_pred_text, + ) + + if cfg.align_using_pred_text: + pred_text_all_lines.extend(pred_text_batch) + + alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device) + + make_ctm( + token_info_batch, + alignments_batch, + manifest_lines_batch, + model, + cfg.model_downsample_factor, + os.path.join(cfg.output_dir, "tokens"), + cfg.remove_blank_tokens_from_ctm, + cfg.audio_filepath_parts_in_utt_id, + cfg.minimum_timestamp_duration, + audio_sr, + ) + + make_ctm( + word_info_batch, + alignments_batch, + manifest_lines_batch, + model, + cfg.model_downsample_factor, + os.path.join(cfg.output_dir, "words"), + False, # dont try to remove blank tokens because we dont expect them to be there anyway + cfg.audio_filepath_parts_in_utt_id, + cfg.minimum_timestamp_duration, + audio_sr, + ) + + if cfg.additional_ctm_grouping_separator: + make_ctm( + segment_info_batch, + alignments_batch, + manifest_lines_batch, + model, + cfg.model_downsample_factor, + os.path.join(cfg.output_dir, "additional_segments"), + False, # dont try to remove blank tokens because we dont expect them to be there anyway + cfg.audio_filepath_parts_in_utt_id, + cfg.minimum_timestamp_duration, + audio_sr, + ) + + make_new_manifest( + cfg.output_dir, + cfg.manifest_filepath, + cfg.additional_ctm_grouping_separator, + cfg.audio_filepath_parts_in_utt_id, + pred_text_all_lines, + ) + + return None + + +if __name__ == "__main__": + main() diff --git a/tools/nemo_forced_aligner/requirements.txt b/tools/nemo_forced_aligner/requirements.txt new file mode 100644 index 000000000000..3af8ebf1b488 --- /dev/null +++ b/tools/nemo_forced_aligner/requirements.txt @@ -0,0 +1,2 @@ +nemo_toolkit[all] +pytest diff --git a/tools/nemo_forced_aligner/tests/test_add_t_start_end_to_boundary_info.py b/tools/nemo_forced_aligner/tests/test_add_t_start_end_to_boundary_info.py new file mode 100644 index 000000000000..406c4be1fb70 --- /dev/null +++ b/tools/nemo_forced_aligner/tests/test_add_t_start_end_to_boundary_info.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
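+
+"""
+Tests for add_t_start_end_to_boundary_info. The ALIGNMENT and *_INFO fixtures below correspond to the
+'hi world | hey' example used in the docstrings of the NFA utils: the alignment maps each timestep to a
+position in the blank-interleaved token sequence, and the expected outputs add 't_start'/'t_end' timesteps
+to each token/word/segment.
+"""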
+ +import pytest +from utils.make_output_files import add_t_start_end_to_boundary_info + +ALIGNMENT = [ + 1, + 1, + 3, + 3, + 4, + 5, + 7, + 7, + 9, + 10, + 11, + 12, + 13, + 15, + 17, + 17, + 19, + 21, + 23, + 23, +] + +INPUT_TOKEN_INFO = [ + {'text': '', 's_start': 0, 's_end': 0}, + {'text': 'h', 's_start': 1, 's_end': 1}, + {'text': '', 's_start': 2, 's_end': 2}, + {'text': 'i', 's_start': 3, 's_end': 3}, + {'text': '', 's_start': 4, 's_end': 4}, + {'text': '', 's_start': 5, 's_end': 5}, + {'text': '', 's_start': 6, 's_end': 6}, + {'text': 'w', 's_start': 7, 's_end': 7}, + {'text': '', 's_start': 8, 's_end': 8}, + {'text': 'o', 's_start': 9, 's_end': 9}, + {'text': '', 's_start': 10, 's_end': 10}, + {'text': 'r', 's_start': 11, 's_end': 11}, + {'text': '', 's_start': 12, 's_end': 12}, + {'text': 'l', 's_start': 13, 's_end': 13}, + {'text': '', 's_start': 14, 's_end': 14}, + {'text': 'd', 's_start': 15, 's_end': 15}, + {'text': '', 's_start': 16, 's_end': 16}, + {'text': '', 's_start': 17, 's_end': 17}, + {'text': '', 's_start': 18, 's_end': 18}, + {'text': 'h', 's_start': 19, 's_end': 19}, + {'text': '', 's_start': 20, 's_end': 20}, + {'text': 'e', 's_start': 21, 's_end': 21}, + {'text': '', 's_start': 22, 's_end': 22}, + {'text': 'y', 's_start': 23, 's_end': 23}, + {'text': '', 's_start': 24, 's_end': 24}, +] + +EXPECTED_OUTPUT_TOKEN_INFO = [ + {'text': 'h', 's_start': 1, 's_end': 1, 't_start': 0, 't_end': 1}, + {'text': 'i', 's_start': 3, 's_end': 3, 't_start': 2, 't_end': 3}, + {'text': '', 's_start': 4, 's_end': 4, 't_start': 4, 't_end': 4}, + {'text': '', 's_start': 5, 's_end': 5, 't_start': 5, 't_end': 5}, + {'text': 'w', 's_start': 7, 's_end': 7, 't_start': 6, 't_end': 7}, + {'text': 'o', 's_start': 9, 's_end': 9, 't_start': 8, 't_end': 8}, + {'text': '', 's_start': 10, 's_end': 10, 't_start': 9, 't_end': 9}, + {'text': 'r', 's_start': 11, 's_end': 11, 't_start': 10, 't_end': 10}, + {'text': '', 's_start': 12, 's_end': 12, 't_start': 11, 't_end': 11}, + {'text': 'l', 's_start': 13, 's_end': 13, 't_start': 12, 't_end': 12}, + {'text': 'd', 's_start': 15, 's_end': 15, 't_start': 13, 't_end': 13}, + {'text': '', 's_start': 17, 's_end': 17, 't_start': 14, 't_end': 15}, + {'text': 'h', 's_start': 19, 's_end': 19, 't_start': 16, 't_end': 16}, + {'text': 'e', 's_start': 21, 's_end': 21, 't_start': 17, 't_end': 17}, + {'text': 'y', 's_start': 23, 's_end': 23, 't_start': 18, 't_end': 19}, +] + + +INPUT_WORD_INFO = [ + {'text': 'hi', 's_start': 1, 's_end': 3}, + {'text': 'world', 's_start': 7, 's_end': 15}, + {'text': 'hey', 's_start': 19, 's_end': 23}, +] + +EXPECTED_OUTPUT_WORD_INFO = [ + {'text': 'hi', 's_start': 1, 's_end': 3, 't_start': 0, 't_end': 3}, + {'text': 'world', 's_start': 7, 's_end': 15, 't_start': 6, 't_end': 13}, + {'text': 'hey', 's_start': 19, 's_end': 23, 't_start': 16, 't_end': 19}, +] + +INPUT_SEGMENT_INFO = [ + {'text': 'hi world', 's_start': 1, 's_end': 15}, + {'text': 'hey', 's_start': 19, 's_end': 23}, +] + +EXPECTED_OUTPUT_SEGMENT_INFO = [ + {'text': 'hi world', 's_start': 1, 's_end': 15, 't_start': 0, 't_end': 13}, + {'text': 'hey', 's_start': 19, 's_end': 23, 't_start': 16, 't_end': 19}, +] + + +@pytest.mark.parametrize( + "input_boundary_info_utt,alignment_utt,expected_output_boundary_info_utt", + [ + (INPUT_TOKEN_INFO, ALIGNMENT, EXPECTED_OUTPUT_TOKEN_INFO), + (INPUT_WORD_INFO, ALIGNMENT, EXPECTED_OUTPUT_WORD_INFO), + (INPUT_SEGMENT_INFO, ALIGNMENT, EXPECTED_OUTPUT_SEGMENT_INFO), + ], +) +def test_add_t_start_end_to_boundary_info(input_boundary_info_utt, 
alignment_utt, expected_output_boundary_info_utt): + output_boundary_info_utt = add_t_start_end_to_boundary_info(input_boundary_info_utt, alignment_utt) + assert output_boundary_info_utt == expected_output_boundary_info_utt diff --git a/tools/nemo_forced_aligner/tests/test_get_y_and_boundary_info_for_utt.py b/tools/nemo_forced_aligner/tests/test_get_y_and_boundary_info_for_utt.py new file mode 100644 index 000000000000..f5bc722d5a1c --- /dev/null +++ b/tools/nemo_forced_aligner/tests/test_get_y_and_boundary_info_for_utt.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from utils.data_prep import get_y_and_boundary_info_for_utt + +from nemo.collections.asr.models import ASRModel + +EN_TEXT = "hi world | hey" + +EN_QN_EXPECTED_TOKEN_INFO = [ + {'text': '', 's_start': 0, 's_end': 0}, + {'text': 'h', 's_start': 1, 's_end': 1}, + {'text': '', 's_start': 2, 's_end': 2}, + {'text': 'i', 's_start': 3, 's_end': 3}, + {'text': '', 's_start': 4, 's_end': 4}, + {'text': '', 's_start': 5, 's_end': 5}, + {'text': '', 's_start': 6, 's_end': 6}, + {'text': 'w', 's_start': 7, 's_end': 7}, + {'text': '', 's_start': 8, 's_end': 8}, + {'text': 'o', 's_start': 9, 's_end': 9}, + {'text': '', 's_start': 10, 's_end': 10}, + {'text': 'r', 's_start': 11, 's_end': 11}, + {'text': '', 's_start': 12, 's_end': 12}, + {'text': 'l', 's_start': 13, 's_end': 13}, + {'text': '', 's_start': 14, 's_end': 14}, + {'text': 'd', 's_start': 15, 's_end': 15}, + {'text': '', 's_start': 16, 's_end': 16}, + {'text': '', 's_start': 17, 's_end': 17}, + {'text': '', 's_start': 18, 's_end': 18}, + {'text': 'h', 's_start': 19, 's_end': 19}, + {'text': '', 's_start': 20, 's_end': 20}, + {'text': 'e', 's_start': 21, 's_end': 21}, + {'text': '', 's_start': 22, 's_end': 22}, + {'text': 'y', 's_start': 23, 's_end': 23}, + {'text': '', 's_start': 24, 's_end': 24}, +] + +EN_QN_EXPECTED_WORD_INFO = [ + {'text': 'hi', 's_start': 1, 's_end': 3}, + {'text': 'world', 's_start': 7, 's_end': 15}, + {'text': 'hey', 's_start': 19, 's_end': 23}, +] + +EN_QN_EXPECTED_SEGMENT_INFO = [ + {'text': 'hi world', 's_start': 1, 's_end': 15}, + {'text': 'hey', 's_start': 19, 's_end': 23}, +] + +EN_CN_EXPECTED_TOKEN_INFO = [ + {'text': '', 's_start': 0, 's_end': 0}, + {'text': '▁hi', 's_start': 1, 's_end': 1}, + {'text': '', 's_start': 2, 's_end': 2}, + {'text': '▁world', 's_start': 3, 's_end': 3}, + {'text': '', 's_start': 4, 's_end': 4}, + {'text': '▁he', 's_start': 5, 's_end': 5}, + {'text': '', 's_start': 6, 's_end': 6}, + {'text': 'y', 's_start': 7, 's_end': 7}, + {'text': '', 's_start': 8, 's_end': 8}, +] + +EN_CN_EXPECTED_WORD_INFO = [ + {'text': 'hi', 's_start': 1, 's_end': 1}, + {'text': 'world', 's_start': 3, 's_end': 3}, + {'text': 'hey', 's_start': 5, 's_end': 7}, +] + +EN_CN_EXPECTED_SEGMENT_INFO = [ + {'text': 'hi world', 's_start': 1, 's_end': 3}, + {'text': 'hey', 's_start': 5, 's_end': 7}, +] + + +ZH_TEXT = "人工 智能|技术" + +ZH_EXPECTED_TOKEN_INFO = [ + 
{'text': '', 's_start': 0, 's_end': 0}, + {'text': '人', 's_start': 1, 's_end': 1}, + {'text': '', 's_start': 2, 's_end': 2}, + {'text': '工', 's_start': 3, 's_end': 3}, + {'text': '', 's_start': 4, 's_end': 4}, + {'text': '', 's_start': 5, 's_end': 5}, + {'text': '', 's_start': 6, 's_end': 6}, + {'text': '智', 's_start': 7, 's_end': 7}, + {'text': '', 's_start': 8, 's_end': 8}, + {'text': '能', 's_start': 9, 's_end': 9}, + {'text': '', 's_start': 10, 's_end': 10}, + {'text': '', 's_start': 11, 's_end': 11}, + {'text': '', 's_start': 12, 's_end': 12}, + {'text': '技', 's_start': 13, 's_end': 13}, + {'text': '', 's_start': 14, 's_end': 14}, + {'text': '术', 's_start': 15, 's_end': 15}, + {'text': '', 's_start': 16, 's_end': 16}, +] + +ZH_EXPECTED_WORD_INFO = [ + {'text': '人工', 's_start': 1, 's_end': 3}, + {'text': '智能', 's_start': 7, 's_end': 9}, + {'text': '技术', 's_start': 13, 's_end': 15}, +] + +ZH_EXPECTED_SEGMENT_INFO = [ + {'text': '人工 智能', 's_start': 1, 's_end': 9}, + {'text': '技术', 's_start': 13, 's_end': 15}, +] + + +@pytest.mark.parametrize( + "text,model_pretrained_name,separator,expected_token_info", + [ + (EN_TEXT, "stt_en_quartznet15x5", "|", EN_QN_EXPECTED_TOKEN_INFO), + (EN_TEXT, "stt_en_citrinet_256_gamma_0_25", "|", EN_CN_EXPECTED_TOKEN_INFO), + (ZH_TEXT, "stt_zh_citrinet_512", "|", ZH_EXPECTED_TOKEN_INFO), + ], +) +def test_token_info(text, model_pretrained_name, separator, expected_token_info): + model = ASRModel.from_pretrained(model_pretrained_name) + _, token_info, *_ = get_y_and_boundary_info_for_utt(text, model, separator) + assert token_info == expected_token_info + + +@pytest.mark.parametrize( + "text,model_pretrained_name,separator,expected_word_info", + [ + (EN_TEXT, "stt_en_quartznet15x5", "|", EN_QN_EXPECTED_WORD_INFO), + (EN_TEXT, "stt_en_citrinet_256_gamma_0_25", "|", EN_CN_EXPECTED_WORD_INFO), + (ZH_TEXT, "stt_zh_citrinet_512", "|", ZH_EXPECTED_WORD_INFO), + ], +) +def test_word_info(text, model_pretrained_name, separator, expected_word_info): + model = ASRModel.from_pretrained(model_pretrained_name) + _, _, word_info, _ = get_y_and_boundary_info_for_utt(text, model, separator) + assert word_info == expected_word_info + + +@pytest.mark.parametrize( + "text,model_pretrained_name,separator,expected_segment_info", + [ + (EN_TEXT, "stt_en_quartznet15x5", "|", EN_QN_EXPECTED_SEGMENT_INFO), + (EN_TEXT, "stt_en_citrinet_256_gamma_0_25", "|", EN_CN_EXPECTED_SEGMENT_INFO), + (ZH_TEXT, "stt_zh_citrinet_512", "|", ZH_EXPECTED_SEGMENT_INFO), + ], +) +def test_segment_info(text, model_pretrained_name, separator, expected_segment_info): + model = ASRModel.from_pretrained(model_pretrained_name) + *_, segment_info = get_y_and_boundary_info_for_utt(text, model, separator) + assert segment_info == expected_segment_info diff --git a/tools/nemo_forced_aligner/utils/constants.py b/tools/nemo_forced_aligner/utils/constants.py new file mode 100644 index 000000000000..894f880401cb --- /dev/null +++ b/tools/nemo_forced_aligner/utils/constants.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +BLANK_TOKEN = "<b>" + +SPACE_TOKEN = "<space>" + +V_NEGATIVE_NUM = -1e30 diff --git a/tools/nemo_forced_aligner/utils/data_prep.py b/tools/nemo_forced_aligner/utils/data_prep.py new file mode 100644 index 000000000000..26d8a328b50d --- /dev/null +++ b/tools/nemo_forced_aligner/utils/data_prep.py @@ -0,0 +1,385 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import soundfile as sf +import torch +from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM + + +def get_batch_starts_ends(manifest_filepath, batch_size): + """ + Get the start and end ids of the lines we will use for each 'batch'. + """ + + with open(manifest_filepath, 'r') as f: + num_lines_in_manifest = sum(1 for _ in f) + + starts = [x for x in range(0, num_lines_in_manifest, batch_size)] + ends = [x - 1 for x in starts] + ends.pop(0) + ends.append(num_lines_in_manifest) + + return starts, ends + + +def is_entry_in_any_lines(manifest_filepath, entry): + """ + Returns True if entry is a key in any of the JSON lines in manifest_filepath + """ + + entry_in_manifest = False + + with open(manifest_filepath, 'r') as f: + for line in f: + data = json.loads(line) + + if entry in data: + entry_in_manifest = True + + return entry_in_manifest + + +def is_entry_in_all_lines(manifest_filepath, entry): + """ + Returns True if entry is a key in all of the JSON lines in manifest_filepath.
+ """ + with open(manifest_filepath, 'r') as f: + for line in f: + data = json.loads(line) + + if entry not in data: + return False + + return True + + +def get_audio_sr(manifest_filepath): + """ + Measure the sampling rate of the audio file in the first line + of the manifest at manifest_filepath + """ + with open(manifest_filepath, "r") as f_manifest: + first_line = json.loads(f_manifest.readline()) + + audio_file = first_line["audio_filepath"] + if not os.path.exists(audio_file): + raise RuntimeError(f"Did not find filepath {audio_file} which was specified in manifest {manifest_filepath}.") + + with sf.SoundFile(audio_file, "r") as f_audio: + return f_audio.samplerate + + +def get_manifest_lines_batch(manifest_filepath, start, end): + manifest_lines_batch = [] + with open(manifest_filepath, "r") as f: + for line_i, line in enumerate(f): + if line_i == start and line_i == end: + manifest_lines_batch.append(json.loads(line)) + break + + if line_i == end: + break + if line_i >= start: + manifest_lines_batch.append(json.loads(line)) + return manifest_lines_batch + + +def get_char_tokens(text, model): + tokens = [] + for character in text: + if character in model.decoder.vocabulary: + tokens.append(model.decoder.vocabulary.index(character)) + else: + tokens.append(len(model.decoder.vocabulary)) # return unk token (same as blank token) + + return tokens + + +def get_y_and_boundary_info_for_utt(text, model, separator): + """ + Get y_token_ids_with_blanks, token_info, word_info and segment_info for the text provided, tokenized + by the model provided. + y_token_ids_with_blanks is a list of the indices of the text tokens with the blank token id in between every + text token. + token_info, word_info and segment_info are lists of dictionaries containing information about + where the tokens/words/segments start and end. + For example, 'hi world | hey ' with separator = '|' and tokenized by a BPE tokenizer can have token_info like: + token_info = [ + {'text': '', 's_start': 0, 's_end': 0}, + {'text': '▁hi', 's_start': 1, 's_end': 1}, + {'text': '', 's_start': 2, 's_end': 2}, + {'text': '▁world', 's_start': 3, 's_end': 3}, + {'text': '', 's_start': 4, 's_end': 4}, + {'text': '▁he', 's_start': 5, 's_end': 5}, + {'text': '', 's_start': 6, 's_end': 6}, + {'text': 'y', 's_start': 7, 's_end': 7}, + {'text': '', 's_start': 8, 's_end': 8}, + ] + 's_start' and 's_end' indicate where in the sequence of tokens does each token start and end. + + The word_info will be as follows: + word_info = [ + {'text': 'hi', 's_start': 1, 's_end': 1}, + {'text': 'world', 's_start': 3, 's_end': 3}, + {'text': 'hey', 's_start': 5, 's_end': 7}, + ] + 's_start' and 's_end' indicate where in the sequence of tokens does each word start and end. + + segment_info will be as follows: + segment_info = [ + {'text': 'hi world', 's_start': 1, 's_end': 3}, + {'text': 'hey', 's_start': 5, 's_end': 7}, + ] + 's_start' and 's_end' indicate where in the sequence of tokens does each segment start and end. 
+ """ + + if not separator: # if separator is not defined - treat the whole text as one segment + segments = [text] + else: + segments = text.split(separator) + + # remove any spaces at start and end of segments + segments = [seg.strip() for seg in segments] + + if hasattr(model, 'tokenizer'): + + BLANK_ID = len(model.decoder.vocabulary) # TODO: check + + y_token_ids_with_blanks = [BLANK_ID] + token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}] + word_info = [] + segment_info = [] + + segment_s_pointer = 1 # first segment will start at s=1 because s=0 is a blank + word_s_pointer = 1 # first word will start at s=1 because s=0 is a blank + + for segment in segments: + words = segment.split(" ") # we define words to be space-separated sub-strings + for word in words: + + word_tokens = model.tokenizer.text_to_tokens(word) + word_ids = model.tokenizer.text_to_ids(word) + for token, id_ in zip(word_tokens, word_ids): + # add the text token and the blank that follows it + # to our token-based variables + y_token_ids_with_blanks.extend([id_, BLANK_ID]) + token_info.extend( + [ + { + "text": token, + "s_start": len(y_token_ids_with_blanks) - 2, + "s_end": len(y_token_ids_with_blanks) - 2, + }, + { + "text": BLANK_TOKEN, + "s_start": len(y_token_ids_with_blanks) - 1, + "s_end": len(y_token_ids_with_blanks) - 1, + }, + ] + ) + + # add the word to word_info and increment the word_s_pointer + word_info.append( + { + "text": word, + "s_start": word_s_pointer, + "s_end": word_s_pointer + (len(word_tokens) - 1) * 2, # TODO check this, + } + ) + word_s_pointer += len(word_tokens) * 2 # TODO check this + + # add the segment to segment_info and increment the segment_s_pointer + segment_tokens = model.tokenizer.text_to_tokens(segment) + segment_info.append( + { + "text": segment, + "s_start": segment_s_pointer, + "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2, + } + ) + segment_s_pointer += len(segment_tokens) * 2 + + return y_token_ids_with_blanks, token_info, word_info, segment_info + + elif hasattr(model.decoder, "vocabulary"): # i.e. 
tokenization is simply character-based + + BLANK_ID = len(model.decoder.vocabulary) # TODO: check this is correct + SPACE_ID = model.decoder.vocabulary.index(" ") + + y_token_ids_with_blanks = [BLANK_ID] + token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}] + word_info = [] + segment_info = [] + + segment_s_pointer = 1 # first segment will start at s=1 because s=0 is a blank + word_s_pointer = 1 # first word will start at s=1 because s=0 is a blank + + for i_segment, segment in enumerate(segments): + words = segment.split(" ") # we define words to be space-separated characters + for i_word, word in enumerate(words): + + # convert string to list of characters + word_tokens = list(word) + # convert list of characters to list of their ids in the vocabulary + word_ids = get_char_tokens(word, model) + for token, id_ in zip(word_tokens, word_ids): + # add the text token and the blank that follows it + # to our token-based variables + y_token_ids_with_blanks.extend([id_, BLANK_ID]) + token_info.extend( + [ + { + "text": token, + "s_start": len(y_token_ids_with_blanks) - 2, + "s_end": len(y_token_ids_with_blanks) - 2, + }, + { + "text": BLANK_TOKEN, + "s_start": len(y_token_ids_with_blanks) - 1, + "s_end": len(y_token_ids_with_blanks) - 1, + }, + ] + ) + + # add space token (and the blank after it) unless this is the final word in the final segment + if not (i_segment == len(segments) - 1 and i_word == len(words) - 1): + y_token_ids_with_blanks.extend([SPACE_ID, BLANK_ID]) + token_info.extend( + ( + { + "text": SPACE_TOKEN, + "s_start": len(y_token_ids_with_blanks) - 2, + "s_end": len(y_token_ids_with_blanks) - 2, + }, + { + "text": BLANK_TOKEN, + "s_start": len(y_token_ids_with_blanks) - 1, + "s_end": len(y_token_ids_with_blanks) - 1, + }, + ) + ) + # add the word to word_info and increment the word_s_pointer + word_info.append( + { + "text": word, + "s_start": word_s_pointer, + "s_end": word_s_pointer + len(word_tokens) * 2 - 2, # TODO check this, + } + ) + word_s_pointer += len(word_tokens) * 2 + 2 # TODO check this + + # add the segment to segment_info and increment the segment_s_pointer + segment_tokens = get_char_tokens(segment, model) + segment_info.append( + { + "text": segment, + "s_start": segment_s_pointer, + "s_end": segment_s_pointer + (len(segment_tokens) - 1) * 2, + } + ) + segment_s_pointer += len(segment_tokens) * 2 + 2 + + return y_token_ids_with_blanks, token_info, word_info, segment_info + + else: + raise RuntimeError("Cannot get tokens of this model.") + + +def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator, align_using_pred_text): + """ + Returns: + log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need + during Viterbi decoding. + token_info_list, word_info_list, segment_info_list - these are lists of dictionaries which we will need + for writing the CTM files with the human-readable alignments. + pred_text_list - this is a list of the transcriptions from our model which we will save to our output JSON + file if align_using_pred_text is True. 
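+
+    As a rough guide to the shapes involved (B = number of utterances in the batch): log_probs is (B, T_max, V)
+    where V = len(vocabulary) + 1 for the blank token, y is (B, U_max) and is padded with the value V,
+    and T and U are 1D tensors of length B containing the unpadded lengths of each utterance's log_probs and y.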
+ """ + + # get hypotheses by calling 'transcribe' + # we will use the output log_probs, the duration of the log_probs, + # and (optionally) the predicted ASR text from the hypotheses + audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch] + B = len(audio_filepaths_batch) + with torch.no_grad(): + hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B) + + log_probs_list_batch = [] + T_list_batch = [] + pred_text_batch = [] + for hypothesis in hypotheses: + log_probs_list_batch.append(hypothesis.y_sequence) + T_list_batch.append(hypothesis.y_sequence.shape[0]) + pred_text_batch.append(hypothesis.text) + + # we loop over every line in the manifest that is in our current batch, + # and record the y (list of tokens, including blanks), U (list of lengths of y) and + # token_info_batch, word_info_batch, segment_info_batch + y_list_batch = [] + U_list_batch = [] + token_info_batch = [] + word_info_batch = [] + segment_info_batch = [] + + for i_line, line in enumerate(manifest_lines_batch): + if align_using_pred_text: + gt_text_for_alignment = pred_text_batch[i_line] + else: + gt_text_for_alignment = line["text"] + y_utt, token_info_utt, word_info_utt, segment_info_utt = get_y_and_boundary_info_for_utt( + gt_text_for_alignment, model, separator + ) + + y_list_batch.append(y_utt) + U_list_batch.append(len(y_utt)) + token_info_batch.append(token_info_utt) + word_info_batch.append(word_info_utt) + segment_info_batch.append(segment_info_utt) + + # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding + T_max = max(T_list_batch) + U_max = max(U_list_batch) + # V = the number of tokens in the vocabulary + 1 for the blank token. + V = len(model.decoder.vocabulary) + 1 + T_batch = torch.tensor(T_list_batch) + U_batch = torch.tensor(U_list_batch) + + # make log_probs_batch tensor of shape (B x T_max x V) + log_probs_batch = V_NEGATIVE_NUM * torch.ones((B, T_max, V)) + for b, log_probs_utt in enumerate(log_probs_list_batch): + t = log_probs_utt.shape[0] + log_probs_batch[b, :t, :] = log_probs_utt + + # make y tensor of shape (B x U_max) + # populate it initially with all 'V' numbers so that the 'V's will remain in the areas that + # are 'padding'. This will be useful for when we make 'log_probs_reorderd' during Viterbi decoding + # in a different function. + y_batch = V * torch.ones((B, U_max), dtype=torch.int64) + for b, y_utt in enumerate(y_list_batch): + U_utt = U_batch[b] + y_batch[b, :U_utt] = torch.tensor(y_utt) + + return ( + log_probs_batch, + y_batch, + T_batch, + U_batch, + token_info_batch, + word_info_batch, + segment_info_batch, + pred_text_batch, + ) diff --git a/tools/nemo_forced_aligner/utils/make_output_files.py b/tools/nemo_forced_aligner/utils/make_output_files.py new file mode 100644 index 000000000000..830bf476ff2f --- /dev/null +++ b/tools/nemo_forced_aligner/utils/make_output_files.py @@ -0,0 +1,210 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from pathlib import Path + +import soundfile as sf +from utils.constants import BLANK_TOKEN, SPACE_TOKEN + + +def _get_utt_id(audio_filepath, audio_filepath_parts_in_utt_id): + fp_parts = Path(audio_filepath).parts[-audio_filepath_parts_in_utt_id:] + utt_id = Path("_".join(fp_parts)).stem + utt_id = utt_id.replace(" ", "-") # replace any spaces in the filepath with dashes + return utt_id + + +def add_t_start_end_to_boundary_info(boundary_info_utt, alignment_utt): + """ + We use the list of alignments to add the timesteps where each token/word/segment is predicted to + start and end. + boundary_info_utt can be any one of the variables referred to as `token_info`, `word_info`, `segment_info` + in other parts of the code. + + e.g. the input boundary info could be + boundary_info_utt = [ + {'text': 'hi', 's_start': 1, 's_end': 3}, + {'text': 'world', 's_start': 7, 's_end': 15}, + {'text': 'hey', 's_start': 19, 's_end': 23}, + ] + + and the alignment could be + alignment_utt = [ 1, 1, 3, 3, 4, 5, 7, 7, 9, 10, 11, 12, 13, 15, 17, 17, 19, 21, 23, 23] + + in which case the output would be: + boundary_info_utt = [ + {'text': 'hi', 's_start': 1, 's_end': 3, 't_start': 0, 't_end': 3}, + {'text': 'world', 's_start': 7, 's_end': 15, 't_start': 6, 't_end': 13}, + {'text': 'hey', 's_start': 19, 's_end': 23, 't_start': 16, 't_end': 19}, + ] + """ + # first remove boundary_info of any items that are not in the alignment + # the only items we expect not to be in the alignment are blanks that the alignment chooses to skip + # we will iterate boundary_info in reverse order for this to make popping the items simple + s_in_alignment = set(alignment_utt) + for boundary_info_pointer in range(len(boundary_info_utt) - 1, -1, -1): + s_in_boundary_info = set( + range( + boundary_info_utt[boundary_info_pointer]["s_start"], + boundary_info_utt[boundary_info_pointer]["s_end"] + 1, + ) + ) + item_not_in_alignment = True + for s_ in s_in_boundary_info: + if s_ in s_in_alignment: + item_not_in_alignment = False + + if item_not_in_alignment: + boundary_info_utt.pop(boundary_info_pointer) + + # now update boundary_info with t_start and t_end + boundary_info_pointer = 0 + for t, s_at_t in enumerate(alignment_utt): + if s_at_t == boundary_info_utt[boundary_info_pointer]["s_start"]: + if "t_start" not in boundary_info_utt[boundary_info_pointer]: + # we have just reached the start of the word/token/segment in the alignment => update t_start + boundary_info_utt[boundary_info_pointer]["t_start"] = t + + if t < len(alignment_utt) - 1: # this if is to avoid accessing an index that is not in the list + if alignment_utt[t + 1] > boundary_info_utt[boundary_info_pointer]["s_end"]: + if "t_end" not in boundary_info_utt[boundary_info_pointer]: + boundary_info_utt[boundary_info_pointer]["t_end"] = t + + boundary_info_pointer += 1 + else: # i.e. t == len(alignment) - 1, i.e. 
we are a the final element in alignment + # add final t_end if we haven't already + if "t_end" not in boundary_info_utt[boundary_info_pointer]: + boundary_info_utt[boundary_info_pointer]["t_end"] = t + + if boundary_info_pointer == len(boundary_info_utt): + # we have finished populating boundary_info with t_start and t_end, + # but we might have some final remaining elements (blanks) in the alignment which we dont care about + # => break, so as not to cause issues trying to access boundary_info[boundary_info_pointer] + break + + return boundary_info_utt + + +def make_ctm( + boundary_info_batch, + alignments_batch, + manifest_lines_batch, + model, + model_downsample_factor, + output_dir, + remove_blank_tokens_from_ctm, + audio_filepath_parts_in_utt_id, + minimum_timestamp_duration, + audio_sr, +): + """ + Function to save CTM files for all the utterances in the incoming batch. + """ + + assert len(boundary_info_batch) == len(alignments_batch) == len(manifest_lines_batch) + # we also assume that utterances are in the same order in boundary_info_batch, alignments_batch + # and manifest_lines_batch - this should be the case unless there is a strange bug upstream in the + # code + + os.makedirs(output_dir, exist_ok=True) + + # the ratio to convert from timesteps (the units of 't_start' and 't_end' in boundary_info_utt) + # to the number of samples ('samples' in the sense of 16000 'samples' per second) + timestep_to_sample_ratio = model.preprocessor.featurizer.hop_length * model_downsample_factor + + for boundary_info_utt, alignment_utt, manifest_line in zip( + boundary_info_batch, alignments_batch, manifest_lines_batch + ): + + boundary_info_utt = add_t_start_end_to_boundary_info(boundary_info_utt, alignment_utt) + + # get utt_id that will be used for saving CTM file as .ctm + utt_id = _get_utt_id(manifest_line['audio_filepath'], audio_filepath_parts_in_utt_id) + + # get audio file duration if we will need it later + if minimum_timestamp_duration > 0: + with sf.SoundFile(manifest_line["audio_filepath"]) as f: + audio_file_duration = f.frames / f.samplerate + + with open(os.path.join(output_dir, f"{utt_id}.ctm"), "w") as f_ctm: + for boundary_info_ in boundary_info_utt: # loop over every token/word/segment + text = boundary_info_["text"] + start_sample = boundary_info_["t_start"] * timestep_to_sample_ratio + end_sample = (boundary_info_["t_end"] + 1) * timestep_to_sample_ratio - 1 + + start_time = start_sample / audio_sr + end_time = end_sample / audio_sr + + if minimum_timestamp_duration > 0 and minimum_timestamp_duration > end_time - start_time: + # make the predicted duration of the token/word/segment longer, growing it outwards equal + # amounts from the predicted center of the token/word/segment + token_mid_point = (start_time + end_time) / 2 + start_time = max(token_mid_point - minimum_timestamp_duration / 2, 0) + end_time = min(token_mid_point + minimum_timestamp_duration / 2, audio_file_duration) + + if not (text == BLANK_TOKEN and remove_blank_tokens_from_ctm): # don't save blanks if we don't want to + # replace any spaces with so we dont introduce extra space characters to our CTM files + text = text.replace(" ", SPACE_TOKEN) + + f_ctm.write(f"{utt_id} 1 {start_time:.2f} {end_time - start_time:.2f} {text}\n") + + return None + + +def make_new_manifest( + output_dir, + original_manifest_filepath, + additional_ctm_grouping_separator, + audio_filepath_parts_in_utt_id, + pred_text_all_lines, +): + """ + Function to save a new manifest with the same info as the original manifest, but also 
the paths to the + CTM files for each utterance and the "pred_text" if it was used for the alignment. + """ + if pred_text_all_lines: + with open(original_manifest_filepath, 'r') as f: + num_lines_in_manifest = sum(1 for _ in f) + + if not num_lines_in_manifest == len(pred_text_all_lines): + raise RuntimeError( + f"Number of lines in the original manifest ({num_lines_in_manifest}) does not match " + f"the number of pred_texts we have ({len(pred_text_all_lines)}). Something has gone wrong." + ) + + tgt_manifest_name = str(Path(original_manifest_filepath).stem) + "_with_ctm_paths.json" + tgt_manifest_filepath = str(Path(output_dir) / tgt_manifest_name) + + with open(original_manifest_filepath, 'r') as fin, open(tgt_manifest_filepath, 'w') as fout: + for i_line, line in enumerate(fin): + data = json.loads(line) + + utt_id = _get_utt_id(data["audio_filepath"], audio_filepath_parts_in_utt_id) + + data["token_level_ctm_filepath"] = str(Path(output_dir) / "tokens" / f"{utt_id}.ctm") + data["word_level_ctm_filepath"] = str(Path(output_dir) / "words" / f"{utt_id}.ctm") + + if additional_ctm_grouping_separator: + data["additional_segment_level_ctm_filepath"] = str( + Path(output_dir) / "additional_segments" / f"{utt_id}.ctm" + ) + + if pred_text_all_lines: + data['pred_text'] = pred_text_all_lines[i_line] + + new_line = json.dumps(data) + + fout.write(f"{new_line}\n") diff --git a/tools/nemo_forced_aligner/utils/viterbi_decoding.py b/tools/nemo_forced_aligner/utils/viterbi_decoding.py new file mode 100644 index 000000000000..bc9a45dda527 --- /dev/null +++ b/tools/nemo_forced_aligner/utils/viterbi_decoding.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from utils.constants import V_NEGATIVE_NUM + + +def viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device): + """ + Do Viterbi decoding with an efficient algorithm (the only for-loop in the 'forward pass' is over the time dimension). + Args: + log_probs_batch: tensor of shape (B, T_max, V). The parts of log_probs_batch which are 'padding' are filled + with 'V_NEGATIVE_NUM' - a large negative number which represents a very low probability. + y_batch: tensor of shape (B, U_max) - contains token IDs including blanks in every other position. The parts of + y_batch which are padding are filled with the number 'V'. V = the number of tokens in the vocabulary + 1 for + the blank token. + T_batch: tensor of shape (B, 1) - contains the durations of the log_probs_batch (so we can ignore the + parts of log_probs_batch which are padding) + U_batch: tensor of shape (B, 1) - contains the lengths of y_batch (so we can ignore the parts of y_batch + which are padding). + viterbi_device: the torch device on which Viterbi decoding will be done. + + Returns: + alignments_batch: list of lists containing locations for the tokens we align to at each timestep. + Looks like: [[0, 0, 1, 2, 2, 3, 3, ..., ], ..., [0, 1, 2, 2, 2, 3, 4, ....]]. 
+ Each list inside alignments_batch is of length T_batch[location of utt in batch]. + """ + B, T_max, _ = log_probs_batch.shape + U_max = y_batch.shape[1] + + # transfer all tensors to viterbi_device + log_probs_batch = log_probs_batch.to(viterbi_device) + y_batch = y_batch.to(viterbi_device) + T_batch = T_batch.to(viterbi_device) + U_batch = U_batch.to(viterbi_device) + + # make tensor that we will put at timesteps beyond the duration of the audio + padding_for_log_probs = V_NEGATIVE_NUM * torch.ones((B, T_max, 1), device=viterbi_device) + # make log_probs_padded tensor of shape (B, T_max, V +1 ) where all of + # log_probs_padded[:,:,-1] is the 'V_NEGATIVE_NUM' + log_probs_padded = torch.cat((log_probs_batch, padding_for_log_probs), dim=2) + # make log_probs_reordered tensor of shape (B, T_max, U_max) + # it contains the log_probs for only the tokens that are in the Ground Truth, and in the order + # that they occur + log_probs_reordered = torch.gather(input=log_probs_padded, dim=2, index=y_batch.unsqueeze(1).repeat(1, T_max, 1)) + + # initialize tensors of viterbi probabilies and backpointers + v_matrix = V_NEGATIVE_NUM * torch.ones_like(log_probs_reordered) + backpointers = -999 * torch.ones_like(v_matrix) + v_matrix[:, 0, :2] = log_probs_reordered[:, 0, :2] + + # Make a letter_repetition_mask the same shape as y_batch + # the letter_repetition_mask will have 'True' where the token (including blanks) is the same + # as the token two places before it in the ground truth (and 'False everywhere else). + # We will use letter_repetition_mask to determine whether the Viterbi algorithm needs to look two tokens back or + # three tokens back + y_shifted_left = torch.roll(y_batch, shifts=2, dims=1) + letter_repetition_mask = y_batch - y_shifted_left + letter_repetition_mask[:, :2] = 1 # make sure dont apply mask to first 2 tokens + letter_repetition_mask = letter_repetition_mask == 0 + + # bp_absolute_template is a tensor we will need during the Viterbi decoding to convert our argmaxes from indices between 0 and 2, + # to indices in the range (0, U_max-1) indicating from which token the mostly path up to that point came from. + # it is a tensor of shape (B, U_max) that looks like + # bp_absolute_template = [ + # [0, 1, 2, ...,, U_max] + # [0, 1, 2, ...,, U_max] + # [0, 1, 2, ...,, U_max] + # ... 
rows repeated so there are B number of rows in total + # ] + bp_absolute_template = torch.arange(U_max, device=viterbi_device).unsqueeze(0).repeat(B, 1) + + for t in range(1, T_max): + + # e_current is a tensor of shape (B, U_max) of the log probs of every possible token at the current timestep + e_current = log_probs_reordered[:, t, :] + + # v_prev is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and in the same token position + v_prev = v_matrix[:, t - 1, :] + + # v_prev_shifted is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 1 token position back + v_prev_shifted = torch.roll(v_prev, shifts=1, dims=1) + # by doing a roll shift of size 1, we have brought the viterbi probability in the final token position to the + # first token position - let's overcome this by 'zeroing out' the probabilities in the firest token position + v_prev_shifted[:, 0] = V_NEGATIVE_NUM + + # v_prev_shifted2 is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 2 token position back + v_prev_shifted2 = torch.roll(v_prev, shifts=2, dims=1) + v_prev_shifted2[:, :2] = V_NEGATIVE_NUM # zero out as we did for v_prev_shifted + # use our letter_repetition_mask to remove the connections between 2 blanks (so we don't skip over a letter) + # and to remove the connections between 2 consective letters (so we don't skip over a blank) + v_prev_shifted2.masked_fill_(letter_repetition_mask, V_NEGATIVE_NUM) + + # we need this v_prev_dup tensor so we can calculated the viterbi probability of every possible + # token position simultaneously + v_prev_dup = torch.cat( + (v_prev.unsqueeze(2), v_prev_shifted.unsqueeze(2), v_prev_shifted2.unsqueeze(2),), dim=2, + ) + + # candidates_v_current are our candidate viterbi probabilities for every token position, from which + # we will pick the max and record the argmax + candidates_v_current = v_prev_dup + e_current.unsqueeze(2) + v_current, bp_relative = torch.max(candidates_v_current, dim=2) + + # convert our argmaxes from indices between 0 and 2, to indices in the range (0, U_max-1) indicating + # from which token the mostly path up to that point came from + bp_absolute = bp_absolute_template - bp_relative + + # update our tensors containing all the viterbi probabilites and backpointers + v_matrix[:, t, :] = v_current + backpointers[:, t, :] = bp_absolute + + # trace backpointers TODO: parallelize over batch_size + alignments_batch = [] + for b in range(B): + T_b = int(T_batch[b]) + U_b = int(U_batch[b]) + + final_state = int(torch.argmax(v_matrix[b, T_b - 1, U_b - 2 : U_b])) + U_b - 2 + alignment_b = [final_state] + for t in range(T_b - 1, 0, -1): + alignment_b.insert(0, int(backpointers[b, t, alignment_b[0]])) + alignments_batch.append(alignment_b) + + return alignments_batch
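+
+
+if __name__ == "__main__":
+    # Illustrative smoke test only (not part of the NFA pipeline): build a tiny fake batch in the
+    # format documented above and run viterbi_decoding on it. All shapes, IDs and values here are
+    # made up purely to show the expected tensor layout.
+    B, T_max, V = 2, 6, 5  # V = 4 'real' tokens (IDs 0-3) + 1 blank token (ID 4)
+    example_log_probs = torch.log_softmax(torch.randn(B, T_max, V), dim=-1)
+    example_log_probs[1, 5:, :] = V_NEGATIVE_NUM  # second utterance is only 5 timesteps long
+    # y starts with a blank and has a blank after every text token; padding positions are filled with V
+    example_y = torch.tensor([[4, 1, 4, 2, 4], [4, 3, 4, 5, 5]], dtype=torch.int64)
+    example_T = torch.tensor([6, 5])
+    example_U = torch.tensor([5, 3])
+    example_alignments = viterbi_decoding(example_log_probs, example_y, example_T, example_U, torch.device("cpu"))
+    print(example_alignments)  # e.g. [[0, 1, 2, 3, 3, 4], [0, 1, 1, 2, 2]] (exact values depend on the random log-probs)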