From 45baf762eceee097bd1ae9ef198c972a1fb98e6a Mon Sep 17 00:00:00 2001 From: popcornell Date: Fri, 8 Dec 2023 19:16:01 -0500 Subject: [PATCH 01/22] .ctm fix for data simulation Signed-off-by: popcornell --- nemo/collections/asr/parts/utils/data_simulation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index a9a1e10ae385..79384f28ac4c 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -774,7 +774,7 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) - text = f"{session_name} {speaker_id} {align1} {align2} {word} 0\n" + text = f"{session_name} 0 {align1} {align2} {word} {speaker_id}\n" arr.append((align1, text)) return arr From c9a5b53891b57e7e010e47dc38a8e87c6e85c32b Mon Sep 17 00:00:00 2001 From: popcornell Date: Mon, 11 Dec 2023 16:06:12 -0500 Subject: [PATCH 02/22] .ctm fix, channel should be 1 not 0 Signed-off-by: popcornell --- nemo/collections/asr/parts/utils/data_simulation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 79384f28ac4c..fe74a02ff128 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -774,7 +774,7 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), 
self._params.data_simulator.outputs.output_precision) - text = f"{session_name} 0 {align1} {align2} {word} {speaker_id}\n" + text = f"{session_name} 1 {align1} {align2} {word} {speaker_id}\n" arr.append((align1, text)) return arr From ca2786402c1723e90e0e331da874b606a7d07e94 Mon Sep 17 00:00:00 2001 From: popcornell Date: Mon, 11 Dec 2023 16:10:18 -0500 Subject: [PATCH 03/22] .ctm fix, only two na, type and confidence Signed-off-by: popcornell --- nemo/collections/asr/parts/utils/data_simulation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index fe74a02ff128..3aae4dd13dea 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -774,7 +774,7 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) - text = f"{session_name} 1 {align1} {align2} {word} {speaker_id}\n" + text = f"{session_name} 1 {align1} {align2} {word} {speaker_id}\n" arr.append((align1, text)) return arr From 8b62af8c3f26d004c8ef64b2af01a22652131009 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 11 Dec 2023 16:33:44 -0800 Subject: [PATCH 04/22] Revised all the parts in NeMo touching CTM files Signed-off-by: Taejin Park --- .../asr/parts/utils/data_simulation_utils.py | 4 +- .../asr/parts/utils/manifest_utils.py | 62 +++++++++++++++++++ .../create_alignment_manifest.py | 61 +++++++++++++----- .../asr/utils/test_data_simul_utils.py | 51 ++++++++++++++- .../utils/make_ctm_files.py | 13 +++- 5 files changed, 169 insertions(+), 22 deletions(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py 
b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 79384f28ac4c..3ddd556fe822 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -25,7 +25,7 @@ from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, write_text +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, write_text, get_ctm_line from nemo.collections.asr.parts.utils.speaker_utils import labels_to_rttmfile from nemo.utils import logging @@ -774,7 +774,7 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) - text = f"{session_name} 0 {align1} {align2} {word} {speaker_id}\n" + text = get_ctm_line(source=session_name, channel=0, beg_time=align1, duration=align2, token=word, conf=None type=None, speaker=speaker_id) arr.append((align1, text)) return arr diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 39bd6a8a24e7..745a60612f77 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -30,8 +30,70 @@ segments_manifest_to_subsegments_manifest, write_rttm2manifest, ) +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line + from nemo.utils.data_utils import DataStoreObject +def get_ctm_line(source: str, + channel: int, + beg_time: float, + duration: float, + token: str, + conf: float, + type_token: str, + speaker: str, + NA_token: str='NA', + UNK: str='unknown', + default_channel: str='1', + 
output_precision: int= 3 + ) -> str: + """ + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + + CTM Format: + + + Reference: + https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf + + Args: + source (str): is name of the source file, session name or utterance ID + channel (int): is channel number defaults to 1 + beg_time (float): is begin time of the word + duration (float): is duration of the word + token (str): Token or word for the current entry + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) + when no confidence is computed and in the reference data. + type (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when + the speaker has not been determined. + NA_token (str, optional): A token for . Defaults to ''. + output_precision (int, optional): The precision of the output floating point number. Defaults to 3. + + Returns: + str: Return a line in CTM format filled with the given information. 
+ """ + if type(beg_time) != float: + beg_time = round(float(beg_time), output_precision) + if type(duration) != float: + duration = round(float(duration), output_precision) + if channel is not None and type(channel) != int: + channel = str(channel) + if conf is not None and type(conf) != float: + raise ValueError(f"`conf` must be a float, but got {type(conf)}") + if conf is not None and not (0 <= conf <= 1): + raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") + if type_token is not None and type(type_token) != str: + raise ValueError(f"`type` must be a string, but got {type(type)}") + if type_token is not None and type_token not in ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"]: + raise ValueError(f"`type` must be one of ['lex', 'frag', 'fp', 'un-lex', 'for-lex', 'non-lex', 'misc', 'noscore'], but got {type_token}") + if speaker is not None and type(speaker) != str: + raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") + channel = default_channel if channel is None else channel + conf = NA_token if conf is None else conf + speaker = NA_token if speaker is None else speaker + type_token = UNK if type_token is None else type_token + return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_token} {speaker}\n" def rreplace(s: str, old: str, new: str) -> str: """ diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index e2b15b03b842..dbd0c29a68d1 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -16,12 +16,37 @@ import os import shutil from pathlib import Path +from typing import List, Dict -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, get_ctm_line from nemo.utils import logging +def 
get_seg_info_from_ctm_line( + ctm: List[str], + output_precision: int, + speaker_index: int = 7, + beg_time_index: int = 2, + duration_index: int = 3, + ): + """ + Get time stamp information and speaker labels from CTM lines. + This is following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + + Args: + ctm (list): + output_precision (_type_): _description_ -def get_unaligned_files(unaligned_path): + Returns: + _type_: _description_ + """ + speaker_id = ctm[speaker_index] + start = float(ctm[beg_time_index]) + end = float(ctm[beg_time_index]) + float(ctm[duration_index]) + start = round(start, output_precision) + end = round(end, output_precision) + return start, end, speaker_id + +def get_unaligned_files(unaligned_path: str) -> List[str]: """ Get files without alignments in order to filter them out (as they cannot be used for data simulation). In the unaligned file, each line contains the file name and the reason for the unalignment, if necessary to specify. 
@@ -50,7 +75,6 @@ def get_unaligned_files(unaligned_path): skip_files.append(unaligned_file) return skip_files - def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_precision=3): """ Create new CTM entry (to write to output ctm file) @@ -71,7 +95,15 @@ def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_ # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 align1 = float(round(alignments[i - 1], output_precision)) align2 = float(round(alignments[i] - alignments[i - 1], output_precision,)) - text = f"{session_name} {speaker_id} {align1} {align2} {word} 0\n" + text = get_ctm_line(source=session_name, + channel=speaker_id, + beg_time=align1, + duration=align2, + token=word, + conf=0, + type='lex', + speaker=speaker_id, + output_precision=output_precision) arr.append((align1, text)) return arr @@ -95,7 +127,6 @@ def load_librispeech_alignment(alignment_filepath: str) -> dict: alignments[file_id] = (words, timestamps) return alignments - def create_librispeech_ctm_alignments( input_manifest_filepath, base_alignment_path, ctm_output_directory, libri_dataset_split ): @@ -206,11 +237,7 @@ def create_manifest_with_alignments( prev_end = 0 for i in range(len(lines)): ctm = lines[i].split(' ') - speaker_id = ctm[1] - start = float(ctm[2]) - end = float(ctm[2]) + float(ctm[3]) - start = round(start, output_precision) - end = round(end, output_precision) + speaker_id, start, end = get_seg_info_from_ctm_line(ctm=ctm, output_precision=output_precision) interval = start - prev_end if (i == 0 and interval > 0) or (i > 0 and interval > silence_dur_threshold): @@ -231,13 +258,13 @@ def create_manifest_with_alignments( end_times.append(f['duration']) # build target manifest entry - target_manifest.append({}) - target_manifest[tgt_i]['audio_filepath'] = f['audio_filepath'] - target_manifest[tgt_i]['duration'] = f['duration'] - target_manifest[tgt_i]['text'] 
= f['text'] - target_manifest[tgt_i]['words'] = words - target_manifest[tgt_i]['alignments'] = end_times - target_manifest[tgt_i]['speaker_id'] = speaker_id + target_manifest.append({'audio_filepath': f['audio_filepath'], + 'duration': f['duration'], + 'text': f['text'], + 'words': words, + 'alignments': end_times, + 'speaker_id': speaker_id + }) src_i += 1 tgt_i += 1 diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 4592043248ae..1934e907e8d2 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -29,6 +29,7 @@ normalize_audio, read_noise_manifest, ) +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line @pytest.fixture() @@ -128,6 +129,52 @@ def generate_words_and_alignments(sample_index): speaker_id = 'speaker_0' return words, alignments, speaker_id +class TestGetCtmLine: + @pytest.mark.unit + def test_valid_input(self): + # Test with completely valid inputs + result = get_ctm_line( + source="test_source", channel=1, beg_time=0.123, duration=0.456, + token="word", conf=0.789, type_token="lex", speaker="speaker1" + ) + expected = "test_source 1 0.123 0.456 word 0.789 lex speaker1\n" + assert result == expected, "Failed on valid input" + + @pytest.mark.unit + @pytest.mark.parametrize("beg_time, duration", [ + ("not a float", 1.0), + (1.0, "not a float"), + (1, 2.0), # Integers should be converted to float + (2.0, 3) # Same as above + ]) + def test_invalid_types_for_time_duration(self, beg_time, duration): + # Test with invalid types for beg_time and duration + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", channel=1, beg_time=beg_time, duration=duration, + token="word", conf=0.5, type_token="lex", speaker="speaker1" + ) + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, "not a float"]) + def test_invalid_conf_values(self, conf): + # Test with invalid 
values for conf + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", channel=1, beg_time=0.123, duration=0.456, + token="word", conf=conf, type_token="lex", speaker="speaker1" + ) + + @pytest.mark.unit + def test_default_values(self): + # Test with missing optional parameters + result = get_ctm_line( + source="test_source", channel=None, beg_time=0.123, duration=0.456, + token="word", conf=None, type_token=None, speaker=None + ) + expected = "test_source 1 0.123 0.456 word NA unknown NA\n" + assert result == expected, "Failed on default values" + class TestDataSimulatorUtils: # TODO: add tests for all util functions @@ -253,11 +300,11 @@ def test_create_new_ctm_entry(self, annotator): ) assert ctm_list[0] == ( alignments[1], - f"{session_name} {speaker_id} {alignments[1]} {alignments[1]-alignments[0]} {words[1]} 0\n", + f"{session_name} 1 {alignments[1]} {alignments[1]-alignments[0]} {words[1]} NA lex {speaker_id}\n", ) assert ctm_list[1] == ( alignments[2], - f"{session_name} {speaker_id} {alignments[2]} {alignments[2]-alignments[1]} {words[2]} 0\n", + f"{session_name} 1 {alignments[2]} {alignments[2]-alignments[1]} {words[2]} NA lex {speaker_id}\n", ) diff --git a/tools/nemo_forced_aligner/utils/make_ctm_files.py b/tools/nemo_forced_aligner/utils/make_ctm_files.py index f0326c07cf8f..52d7dc965655 100644 --- a/tools/nemo_forced_aligner/utils/make_ctm_files.py +++ b/tools/nemo_forced_aligner/utils/make_ctm_files.py @@ -17,6 +17,7 @@ import soundfile as sf from utils.constants import BLANK_TOKEN, SPACE_TOKEN from utils.data_prep import Segment, Word +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line def make_ctm_files( @@ -105,7 +106,17 @@ def make_ctm( # replace any spaces with so we dont introduce extra space characters to our CTM files text = text.replace(" ", SPACE_TOKEN) - f_ctm.write(f"{utt_obj.utt_id} 1 {start_time:.2f} {end_time - start_time:.2f} {text}\n") + ctm_line = get_ctm_line( + source=utt_obj.utt_id, + 
channel=1, + beg_time=start_time, + duration=end_time - start_time, + token=text, + conf=None, + type_token=None, + speaker=None, + ) + f_ctm.write(ctm_line) utt_obj.saved_output_files[f"{alignment_level}_level_ctm_filepath"] = os.path.join( output_dir, f"{utt_obj.utt_id}.ctm" From dcfb24a1398537ed031c2f372525aa732221c8bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 00:50:09 +0000 Subject: [PATCH 05/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../asr/parts/utils/data_simulation_utils.py | 27 +++++---- .../asr/parts/utils/manifest_utils.py | 44 ++++++++++----- .../create_alignment_manifest.py | 53 ++++++++++-------- .../asr/utils/test_data_simul_utils.py | 56 ++++++++++++++----- 4 files changed, 117 insertions(+), 63 deletions(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 27cf172772e9..fc3909b5ec17 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -25,7 +25,13 @@ from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, write_text, get_ctm_line +from nemo.collections.asr.parts.utils.manifest_utils import ( + get_ctm_line, + read_manifest, + write_ctm, + write_manifest, + write_text, +) from nemo.collections.asr.parts.utils.speaker_utils import labels_to_rttmfile from nemo.utils import logging @@ -774,15 +780,16 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), 
self._params.data_simulator.outputs.output_precision) - text = get_ctm_line(source=session_name, - channel=1, - beg_time=align1, - duration=align2, - token=word, - conf=None, - type_token='lex', - speaker=speaker_id, - ) + text = get_ctm_line( + source=session_name, + channel=1, + beg_time=align1, + duration=align2, + token=word, + conf=None, + type_token='lex', + speaker=speaker_id, + ) arr.append((align1, text)) return arr diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index c03b00783596..ee526ebb7098 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -32,19 +32,21 @@ ) from nemo.utils.data_utils import DataStoreObject -def get_ctm_line(source: str, - channel: int, - beg_time: float, - duration: float, - token: str, - conf: float, - type_token: str, - speaker: str, - NA_token: str='NA', - UNK: str='unknown', - default_channel: str='1', - output_precision: int= 3 - ) -> str: + +def get_ctm_line( + source: str, + channel: int, + beg_time: float, + duration: float, + token: str, + conf: float, + type_token: str, + speaker: str, + NA_token: str = 'NA', + UNK: str = 'unknown', + default_channel: str = '1', + output_precision: int = 3, +) -> str: """ Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. 
@@ -83,8 +85,19 @@ def get_ctm_line(source: str, raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") if type_token is not None and type(type_token) != str: raise ValueError(f"`type` must be a string, but got {type(type)}") - if type_token is not None and type_token not in ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"]: - raise ValueError(f"`type` must be one of ['lex', 'frag', 'fp', 'un-lex', 'for-lex', 'non-lex', 'misc', 'noscore'], but got {type_token}") + if type_token is not None and type_token not in [ + "lex", + "frag", + "fp", + "un-lex", + "for-lex", + "non-lex", + "misc", + "noscore", + ]: + raise ValueError( + f"`type` must be one of ['lex', 'frag', 'fp', 'un-lex', 'for-lex', 'non-lex', 'misc', 'noscore'], but got {type_token}" + ) if speaker is not None and type(speaker) != str: raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") channel = default_channel if channel is None else channel @@ -93,6 +106,7 @@ def get_ctm_line(source: str, type_token = UNK if type_token is None else type_token return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_token} {speaker}\n" + def rreplace(s: str, old: str, new: str) -> str: """ Replace end of string. 
diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index 60417ab7fa2f..13521915967f 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -16,18 +16,15 @@ import os import shutil from pathlib import Path -from typing import List, Dict +from typing import Dict, List -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, get_ctm_line +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line, read_manifest, write_ctm, write_manifest from nemo.utils import logging + def get_seg_info_from_ctm_line( - ctm: List[str], - output_precision: int, - speaker_index: int = 7, - beg_time_index: int = 2, - duration_index: int = 3, - ): + ctm: List[str], output_precision: int, speaker_index: int = 7, beg_time_index: int = 2, duration_index: int = 3, +): """ Get time stamp information and speaker labels from CTM lines. This is following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. @@ -46,6 +43,7 @@ def get_seg_info_from_ctm_line( end = round(end, output_precision) return start, end, speaker_id + def get_unaligned_files(unaligned_path: str) -> List[str]: """ Get files without alignments in order to filter them out (as they cannot be used for data simulation). 
@@ -75,6 +73,7 @@ def get_unaligned_files(unaligned_path: str) -> List[str]: skip_files.append(unaligned_file) return skip_files + def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_precision=3): """ Create new CTM entry (to write to output ctm file) @@ -95,15 +94,17 @@ def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_ # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 align1 = float(round(alignments[i - 1], output_precision)) align2 = float(round(alignments[i] - alignments[i - 1], output_precision,)) - text = get_ctm_line(source=session_name, - channel=speaker_id, - beg_time=align1, - duration=align2, - token=word, - conf=0, - type_token='lex', - speaker=speaker_id, - output_precision=output_precision) + text = get_ctm_line( + source=session_name, + channel=speaker_id, + beg_time=align1, + duration=align2, + token=word, + conf=0, + type_token='lex', + speaker=speaker_id, + output_precision=output_precision, + ) arr.append((align1, text)) return arr @@ -127,6 +128,7 @@ def load_librispeech_alignment(alignment_filepath: str) -> dict: alignments[file_id] = (words, timestamps) return alignments + def create_librispeech_ctm_alignments( input_manifest_filepath, base_alignment_path, ctm_output_directory, libri_dataset_split ): @@ -258,13 +260,16 @@ def create_manifest_with_alignments( end_times.append(f['duration']) # build target manifest entry - target_manifest.append({'audio_filepath': f['audio_filepath'], - 'duration': f['duration'], - 'text': f['text'], - 'words': words, - 'alignments': end_times, - 'speaker_id': speaker_id - }) + target_manifest.append( + { + 'audio_filepath': f['audio_filepath'], + 'duration': f['duration'], + 'text': f['text'], + 'words': words, + 'alignments': end_times, + 'speaker_id': speaker_id, + } + ) src_i += 1 tgt_i += 1 diff --git a/tests/collections/asr/utils/test_data_simul_utils.py 
b/tests/collections/asr/utils/test_data_simul_utils.py index 1934e907e8d2..d5087bf8fe0d 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -129,30 +129,46 @@ def generate_words_and_alignments(sample_index): speaker_id = 'speaker_0' return words, alignments, speaker_id + class TestGetCtmLine: @pytest.mark.unit def test_valid_input(self): # Test with completely valid inputs result = get_ctm_line( - source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=0.789, type_token="lex", speaker="speaker1" + source="test_source", + channel=1, + beg_time=0.123, + duration=0.456, + token="word", + conf=0.789, + type_token="lex", + speaker="speaker1", ) expected = "test_source 1 0.123 0.456 word 0.789 lex speaker1\n" assert result == expected, "Failed on valid input" @pytest.mark.unit - @pytest.mark.parametrize("beg_time, duration", [ - ("not a float", 1.0), - (1.0, "not a float"), - (1, 2.0), # Integers should be converted to float - (2.0, 3) # Same as above - ]) + @pytest.mark.parametrize( + "beg_time, duration", + [ + ("not a float", 1.0), + (1.0, "not a float"), + (1, 2.0), # Integers should be converted to float + (2.0, 3), # Same as above + ], + ) def test_invalid_types_for_time_duration(self, beg_time, duration): # Test with invalid types for beg_time and duration with pytest.raises(ValueError): get_ctm_line( - source="test_source", channel=1, beg_time=beg_time, duration=duration, - token="word", conf=0.5, type_token="lex", speaker="speaker1" + source="test_source", + channel=1, + beg_time=beg_time, + duration=duration, + token="word", + conf=0.5, + type_token="lex", + speaker="speaker1", ) @pytest.mark.unit @@ -161,16 +177,28 @@ def test_invalid_conf_values(self, conf): # Test with invalid values for conf with pytest.raises(ValueError): get_ctm_line( - source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=conf, type_token="lex", 
speaker="speaker1" + source="test_source", + channel=1, + beg_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_token="lex", + speaker="speaker1", ) @pytest.mark.unit def test_default_values(self): # Test with missing optional parameters result = get_ctm_line( - source="test_source", channel=None, beg_time=0.123, duration=0.456, - token="word", conf=None, type_token=None, speaker=None + source="test_source", + channel=None, + beg_time=0.123, + duration=0.456, + token="word", + conf=None, + type_token=None, + speaker=None, ) expected = "test_source 1 0.123 0.456 word NA unknown NA\n" assert result == expected, "Failed on default values" From ccb60055f16065fdc7b580d1bd5312ee91fdb735 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 11 Dec 2023 19:49:13 -0800 Subject: [PATCH 06/22] Updated tutorial, nemo-docs and tests for CTM formats Signed-off-by: Taejin Park --- .../asr/speaker_diarization/datasets.rst | 6 +-- .../offline_diar_with_asr_infer.py | 2 - .../asr/parts/utils/data_simulation_utils.py | 2 +- .../asr/parts/utils/manifest_utils.py | 17 +++---- .../asr/utils/test_data_simul_utils.py | 44 ++++++++++++++++--- .../ASR_with_SpeakerDiarization.ipynb | 24 +++++----- 6 files changed, 63 insertions(+), 32 deletions(-) diff --git a/docs/source/asr/speaker_diarization/datasets.rst b/docs/source/asr/speaker_diarization/datasets.rst index 9f1a43a58f11..952f426b4ff8 100644 --- a/docs/source/asr/speaker_diarization/datasets.rst +++ b/docs/source/asr/speaker_diarization/datasets.rst @@ -205,14 +205,14 @@ The following are descriptions about each field in an input manifest JSON file. ``ctm_filepath`` (Optional): - CTM file is used for the evaluation of word-level diarization results and word-timestamp alignment. CTM file follows the following convention: `` `` Since confidence is not required for evaluating diarization results, it can have any value. Note that the ```` should be exactly matched with speaker IDs in RTTM. 
+ The CTM file is used for the evaluation of word-level diarization results and word-timestamp alignment. The CTM file follows this convention: ``<session name> <channel ID> <start time> <duration> <word> <confidence> <type of token> <speaker>``. Note that the ``<speaker>`` should exactly match speaker IDs in RTTM. Since confidence is not required for evaluating diarization results, we assign ``<confidence>`` the value ``NA``. If the type of token is words, we assign ``<type of token>`` as ``lex``. Example lines of CTM file: .. code-block:: bash - TS3012d.Mix-Headset MTD046ID 12.879 0.32 okay 0 - TS3012d.Mix-Headset MTD046ID 13.203 0.24 yeah 0 + TS3012d.Mix-Headset 1 12.879 0.32 okay NA lex MTD046ID + TS3012d.Mix-Headset 1 13.203 0.24 yeah NA lex MTD046ID Evaluation on Benchmark Datasets diff --git a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py index da9671a2fc43..d15adb5b826a 100644 --- a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py +++ b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py @@ -63,8 +63,6 @@ def main(cfg): # If RTTM is provided and DER evaluation if diar_score is not None: - metric, mapping_dict, _ = diar_score - # Get session-level diarization error rate and speaker counting error der_results = OfflineDiarWithASR.gather_eval_results( diar_score=diar_score, diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 27cf172772e9..ca672d9e4055 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -780,7 +780,7 @@ def create_new_ctm_entry( duration=align2, token=word, conf=None, - type_token='lex', + type_of_token='lex', speaker=speaker_id, ) arr.append((align1, text)) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index c03b00783596..2843e5fa8abf 
100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -38,7 +38,7 @@ def get_ctm_line(source: str, duration: float, token: str, conf: float, - type_token: str, + type_of_token: str, speaker: str, NA_token: str='NA', UNK: str='unknown', @@ -49,7 +49,7 @@ def get_ctm_line(source: str, Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. CTM Format: - + Reference: https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf @@ -62,7 +62,7 @@ def get_ctm_line(source: str, token (str): Token or word for the current entry conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) when no confidence is computed and in the reference data. - type (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. @@ -71,6 +71,7 @@ def get_ctm_line(source: str, Returns: str: Return a line in CTM format filled with the given information. 
""" + VALID_TOKEN_TYPES = ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"] if type(beg_time) != float: beg_time = round(float(beg_time), output_precision) if type(duration) != float: @@ -81,17 +82,17 @@ def get_ctm_line(source: str, raise ValueError(f"`conf` must be a float, but got {type(conf)}") if conf is not None and not (0 <= conf <= 1): raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") - if type_token is not None and type(type_token) != str: + if type_of_token is not None and type(type_of_token) != str: raise ValueError(f"`type` must be a string, but got {type(type)}") - if type_token is not None and type_token not in ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"]: - raise ValueError(f"`type` must be one of ['lex', 'frag', 'fp', 'un-lex', 'for-lex', 'non-lex', 'misc', 'noscore'], but got {type_token}") + if type_of_token is not None and type_of_token not in VALID_TOKEN_TYPES: + raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token}") if speaker is not None and type(speaker) != str: raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") channel = default_channel if channel is None else channel conf = NA_token if conf is None else conf speaker = NA_token if speaker is None else speaker - type_token = UNK if type_token is None else type_token - return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_token} {speaker}\n" + type_of_token = UNK if type_of_token is None else type_of_token + return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_of_token} {speaker}\n" def rreplace(s: str, old: str, new: str) -> str: """ diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 1934e907e8d2..eef950cc9e9e 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -130,12 +130,45 @@ def 
generate_words_and_alignments(sample_index): return words, alignments, speaker_id class TestGetCtmLine: + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0, 1]) + def test_wrong_type_conf_values(self, conf): + # Test with wrong integer confidence values + with pytest.raises(ValueError): + result = get_ctm_line( + source="test_source", channel=1, beg_time=0.123, duration=0.456, + token="word", conf=conf, type_of_token="lex", speaker="speaker1" + ) + expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.001, 0.999]) + def test_valid_conf_values(self, conf): + # Test with valid confidence values + result = get_ctm_line( + source="test_source", channel=1, beg_time=0.123, duration=0.456, + token="word", conf=conf, type_of_token="lex", speaker="speaker1" + ) + expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, 2, -1, 100, -100]) + def test_invalid_conf_ranges(self, conf): + # Test with invalid confidence values + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", channel=1, beg_time=0.123, duration=0.456, + token="word", conf=conf, type_of_token="lex", speaker="speaker1" + ) + @pytest.mark.unit def test_valid_input(self): # Test with completely valid inputs result = get_ctm_line( source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=0.789, type_token="lex", speaker="speaker1" + token="word", conf=0.789, type_of_token="lex", speaker="speaker1" ) expected = "test_source 1 0.123 0.456 word 0.789 lex speaker1\n" assert result == expected, "Failed on valid input" @@ -144,15 +177,14 @@ def test_valid_input(self): @pytest.mark.parametrize("beg_time, duration", [ ("not a float", 1.0), (1.0, "not a float"), - (1, 2.0), # 
Integers should be converted to float - (2.0, 3) # Same as above + ("not 1.01", "not 1.67"), ]) def test_invalid_types_for_time_duration(self, beg_time, duration): # Test with invalid types for beg_time and duration with pytest.raises(ValueError): get_ctm_line( source="test_source", channel=1, beg_time=beg_time, duration=duration, - token="word", conf=0.5, type_token="lex", speaker="speaker1" + token="word", conf=0.5, type_of_token="lex", speaker="speaker1" ) @pytest.mark.unit @@ -162,7 +194,7 @@ def test_invalid_conf_values(self, conf): with pytest.raises(ValueError): get_ctm_line( source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=conf, type_token="lex", speaker="speaker1" + token="word", conf=conf, type_of_token="lex", speaker="speaker1" ) @pytest.mark.unit @@ -170,7 +202,7 @@ def test_default_values(self): # Test with missing optional parameters result = get_ctm_line( source="test_source", channel=None, beg_time=0.123, duration=0.456, - token="word", conf=None, type_token=None, speaker=None + token="word", conf=None, type_of_token=None, speaker=None ) expected = "test_source 1 0.123 0.456 word NA unknown NA\n" assert result == expected, "Failed on default values" diff --git a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb index 005a198720ad..0fb2b62610a6 100644 --- a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb +++ b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb @@ -541,7 +541,7 @@ "- Example:\n", "`diar_session_123 1 13.2 0.25 hi 0 lex speaker_3`\n", "\n", - "For the purpose of creating the reference annotations, we use `` for speaker labels and assign 0 to ``. The reference CTM file for the `an4_diarize_test.wav` looks like the following:" + "For the purpose of creating the reference annotations, we set `1` for `` and assign \"`NA`\" to ``, \"`lex`\" to ``. 
The reference CTM file for the `an4_diarize_test.wav` looks like the following:" ] }, { @@ -551,16 +551,16 @@ "outputs": [], "source": [ "an4_diarize_test_ctm = \\\n", - "[\"an4_diarize_test 1 0.4 0.51 eleven 0 lex speaker_0\",\n", - "\"an4_diarize_test 1 0.95 0.32 twenty 0 lex speaker_0\",\n", - "\"an4_diarize_test 1 1.35 0.55 seven 0 lex speaker_0\",\n", - "\"an4_diarize_test 1 1.96 0.32 fifty 0 lex speaker_0\",\n", - "\"an4_diarize_test 1 2.32 0.75 seven 0 lex speaker_0\",\n", - "\"an4_diarize_test 1 3.12 0.42 october 0 lex speaker_1\",\n", - "\"an4_diarize_test 1 3.6 0.28 twenty 0 lex speaker_1\",\n", - "\"an4_diarize_test 1 3.95 0.35 four 0 lex speaker_1\",\n", - "\"an4_diarize_test 1 4.3 0.31 nineteen 0 lex speaker_1\",\n", - "\"an4_diarize_test 1 4.65 0.35 seventy 0 lex speaker_1\",]" + "[\"an4_diarize_test 1 0.4 0.51 eleven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 0.95 0.32 twenty NA lex speaker_0\",\n", + "\"an4_diarize_test 1 1.35 0.55 seven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 1.96 0.32 fifty NA lex speaker_0\",\n", + "\"an4_diarize_test 1 2.32 0.75 seven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 3.12 0.42 october NA lex speaker_1\",\n", + "\"an4_diarize_test 1 3.6 0.28 twenty NA lex speaker_1\",\n", + "\"an4_diarize_test 1 3.95 0.35 four NA lex speaker_1\",\n", + "\"an4_diarize_test 1 4.3 0.31 nineteen NA lex speaker_1\",\n", + "\"an4_diarize_test 1 4.65 0.35 seventy NA lex speaker_1\",]" ] }, { @@ -959,7 +959,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.13" } }, "nbformat": 4, From e76e85a1dad3a03b77fb6fa8f928155a89935eb8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 03:55:17 +0000 Subject: [PATCH 07/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../asr/utils/test_data_simul_utils.py | 34 
++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 587882b7950b..cfacb6dea548 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -137,19 +137,31 @@ def test_wrong_type_conf_values(self, conf): # Test with wrong integer confidence values with pytest.raises(ValueError): result = get_ctm_line( - source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=conf, type_of_token="lex", speaker="speaker1" + source="test_source", + channel=1, + beg_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", ) expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" assert result == expected, f"Failed on valid conf value {conf}" - + @pytest.mark.unit @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.001, 0.999]) def test_valid_conf_values(self, conf): # Test with valid confidence values result = get_ctm_line( - source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=conf, type_of_token="lex", speaker="speaker1" + source="test_source", + channel=1, + beg_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", ) expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" assert result == expected, f"Failed on valid conf value {conf}" @@ -160,10 +172,16 @@ def test_invalid_conf_ranges(self, conf): # Test with invalid confidence values with pytest.raises(ValueError): get_ctm_line( - source="test_source", channel=1, beg_time=0.123, duration=0.456, - token="word", conf=conf, type_of_token="lex", speaker="speaker1" + source="test_source", + channel=1, + beg_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", ) - + @pytest.mark.unit def 
test_valid_input(self): # Test with completely valid inputs From d44749260b2d023c884ee99f12fab25d3b42ce62 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 11 Dec 2023 20:04:45 -0800 Subject: [PATCH 08/22] Fixed the docstrings in create_alignment_manifest.py Signed-off-by: Taejin Park --- .../create_alignment_manifest.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index d186781c1833..a1bba273e222 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -16,29 +16,31 @@ import os import shutil from pathlib import Path -from typing import Dict, List +from typing import List from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line, read_manifest, write_ctm, write_manifest from nemo.utils import logging def get_seg_info_from_ctm_line( - ctm: List[str], output_precision: int, speaker_index: int = 7, beg_time_index: int = 2, duration_index: int = 3, + ctm_list: List[str], output_precision: int, speaker_index: int = 7, beg_time_index: int = 2, duration_index: int = 3, ): """ Get time stamp information and speaker labels from CTM lines. This is following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. Args: - ctm (list): - output_precision (_type_): _description_ + ctm_list (list): List containing CTM items. e.g.: ['sw02001-A', '1', '0.000', '0.200', 'hello', '0.98', 'lex', 'speaker3'] + output_precision (int): Precision for CTM outputs in integer. Returns: - _type_: _description_ + start (float): Start time of the segment. + end (float): End time of the segment. + speaker_id (str): Speaker ID of the segment. 
""" - speaker_id = ctm[speaker_index] - start = float(ctm[beg_time_index]) - end = float(ctm[beg_time_index]) + float(ctm[duration_index]) + speaker_id = ctm_list[speaker_index] + start = float(ctm_list[beg_time_index]) + end = float(ctm_list[beg_time_index]) + float(ctm_list[duration_index]) start = round(start, output_precision) end = round(end, output_precision) return start, end, speaker_id @@ -239,7 +241,7 @@ def create_manifest_with_alignments( prev_end = 0 for i in range(len(lines)): ctm = lines[i].split(' ') - speaker_id, start, end = get_seg_info_from_ctm_line(ctm=ctm, output_precision=output_precision) + speaker_id, start, end = get_seg_info_from_ctm_line(ctm_list=ctm, output_precision=output_precision) interval = start - prev_end if (i == 0 and interval > 0) or (i > 0 and interval > silence_dur_threshold): From 971485498f6e9b21ed420cd0295865a374c4f316 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 04:05:53 +0000 Subject: [PATCH 09/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/speaker_tasks/create_alignment_manifest.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index a1bba273e222..a239229b05ed 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -23,7 +23,11 @@ def get_seg_info_from_ctm_line( - ctm_list: List[str], output_precision: int, speaker_index: int = 7, beg_time_index: int = 2, duration_index: int = 3, + ctm_list: List[str], + output_precision: int, + speaker_index: int = 7, + beg_time_index: int = 2, + duration_index: int = 3, ): """ Get time stamp information and speaker labels from CTM lines. 
From 0de3390b102473079b3d9942e4e35ce2ecbd8abf Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 11 Dec 2023 20:57:50 -0800 Subject: [PATCH 10/22] Some missing refactored variables for type_of_token Signed-off-by: Taejin Park --- tests/collections/asr/utils/test_data_simul_utils.py | 8 ++++---- tools/nemo_forced_aligner/utils/make_ctm_files.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index cfacb6dea548..934be2ebff1a 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -192,7 +192,7 @@ def test_valid_input(self): duration=0.456, token="word", conf=0.789, - type_token="lex", + type_of_token="lex", speaker="speaker1", ) expected = "test_source 1 0.123 0.456 word 0.789 lex speaker1\n" @@ -218,7 +218,7 @@ def test_invalid_types_for_time_duration(self, beg_time, duration): duration=duration, token="word", conf=0.5, - type_token="lex", + type_of_token="lex", speaker="speaker1", ) @@ -234,7 +234,7 @@ def test_invalid_conf_values(self, conf): duration=0.456, token="word", conf=conf, - type_token="lex", + type_of_token="lex", speaker="speaker1", ) @@ -248,7 +248,7 @@ def test_default_values(self): duration=0.456, token="word", conf=None, - type_token=None, + type_of_token=None, speaker=None, ) expected = "test_source 1 0.123 0.456 word NA unknown NA\n" diff --git a/tools/nemo_forced_aligner/utils/make_ctm_files.py b/tools/nemo_forced_aligner/utils/make_ctm_files.py index 52d7dc965655..fbc5b7f0e15c 100644 --- a/tools/nemo_forced_aligner/utils/make_ctm_files.py +++ b/tools/nemo_forced_aligner/utils/make_ctm_files.py @@ -113,7 +113,7 @@ def make_ctm( duration=end_time - start_time, token=text, conf=None, - type_token=None, + type_of_token='lex', speaker=None, ) f_ctm.write(ctm_line) From 6e6344e32d7291537cbdc7b506e3e646a021cc27 Mon Sep 17 00:00:00 2001 From: Taejin 
Park Date: Mon, 11 Dec 2023 21:51:39 -0800 Subject: [PATCH 11/22] Another un-fixed part in data_simulation_utils.py Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/data_simulation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index fc3909b5ec17..3f9ead1b99a0 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -787,7 +787,7 @@ def create_new_ctm_entry( duration=align2, token=word, conf=None, - type_token='lex', + type_of_token='lex', speaker=speaker_id, ) arr.append((align1, text)) From 8d94c368db82467c31e3e24fe3e558d8a96af624 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 13 Dec 2023 15:42:45 -0800 Subject: [PATCH 12/22] Reflected comments from PR Signed-off-by: Taejin Park --- .../asr/parts/utils/manifest_utils.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index c070564225d3..eaf44bbcb1fb 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -74,10 +74,17 @@ def get_ctm_line( str: Return a line in CTM format filled with the given information. 
""" VALID_TOKEN_TYPES = ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"] - if type(beg_time) != float: - beg_time = round(float(beg_time), output_precision) - if type(duration) != float: - duration = round(float(duration), output_precision) + + if type(beg_time) == str and beg_time.replace('.','',1).isdigit(): + beg_time = float(beg_time) + elif type(beg_time) != float: + raise ValueError(f"`beg_time` must be a float or str containing float, but got {type(beg_time)}") + + if type(duration) == str and duration.replace('.','',1).isdigit(): + duration = float(duration) + elif type(duration) != float: + raise ValueError(f"`duration` must be a float or str containing float, but got {type(duration)}") + if channel is not None and type(channel) != int: channel = str(channel) if conf is not None and type(conf) != float: @@ -90,10 +97,12 @@ def get_ctm_line( raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token}") if speaker is not None and type(speaker) != str: raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") + channel = default_channel if channel is None else channel conf = NA_token if conf is None else conf speaker = NA_token if speaker is None else speaker type_of_token = UNK if type_of_token is None else type_of_token + beg_time, duration = round(beg_time, output_precision), round(float(duration), output_precision) return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_of_token} {speaker}\n" From b4bf97093b99a6d1039bc41c537646bdd3339cf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Dec 2023 23:44:08 +0000 Subject: [PATCH 13/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/collections/asr/parts/utils/manifest_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git 
a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index eaf44bbcb1fb..4795b13a9024 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -74,17 +74,17 @@ def get_ctm_line( str: Return a line in CTM format filled with the given information. """ VALID_TOKEN_TYPES = ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"] - - if type(beg_time) == str and beg_time.replace('.','',1).isdigit(): + + if type(beg_time) == str and beg_time.replace('.', '', 1).isdigit(): beg_time = float(beg_time) elif type(beg_time) != float: raise ValueError(f"`beg_time` must be a float or str containing float, but got {type(beg_time)}") - - if type(duration) == str and duration.replace('.','',1).isdigit(): + + if type(duration) == str and duration.replace('.', '', 1).isdigit(): duration = float(duration) elif type(duration) != float: raise ValueError(f"`duration` must be a float or str containing float, but got {type(duration)}") - + if channel is not None and type(channel) != int: channel = str(channel) if conf is not None and type(conf) != float: @@ -97,7 +97,7 @@ def get_ctm_line( raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token}") if speaker is not None and type(speaker) != str: raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") - + channel = default_channel if channel is None else channel conf = NA_token if conf is None else conf speaker = NA_token if speaker is None else speaker From 724adf3ee8434770412a4ded2c8a8f196c0fd3c9 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 13 Dec 2023 15:50:35 -0800 Subject: [PATCH 14/22] Reflected another precision related comments from PR Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/manifest_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py 
b/nemo/collections/asr/parts/utils/manifest_utils.py index eaf44bbcb1fb..15823d80c603 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -45,7 +45,7 @@ def get_ctm_line( NA_token: str = 'NA', UNK: str = 'unknown', default_channel: str = '1', - output_precision: int = 3, + output_precision: int = 2, ) -> str: """ Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. From 86d5198a19c8e6152673e27214a18e640f20594e Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 13 Dec 2023 15:55:08 -0800 Subject: [PATCH 15/22] Updated tests to use decimal rounding of 2 Signed-off-by: Taejin Park --- tests/collections/asr/utils/test_data_simul_utils.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 934be2ebff1a..9264078329cc 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -146,11 +146,11 @@ def test_wrong_type_conf_values(self, conf): type_of_token="lex", speaker="speaker1", ) - expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" + expected = f"test_source 1 0.12 0.46 word {conf} lex speaker1\n" assert result == expected, f"Failed on valid conf value {conf}" @pytest.mark.unit - @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.001, 0.999]) + @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.01, 0.99]) def test_valid_conf_values(self, conf): # Test with valid confidence values result = get_ctm_line( @@ -163,7 +163,7 @@ def test_valid_conf_values(self, conf): type_of_token="lex", speaker="speaker1", ) - expected = f"test_source 1 0.123 0.456 word {conf} lex speaker1\n" + expected = f"test_source 1 0.12 0.46 word {conf} lex speaker1\n" assert result == expected, f"Failed on valid conf value 
{conf}" @pytest.mark.unit @@ -195,7 +195,7 @@ def test_valid_input(self): type_of_token="lex", speaker="speaker1", ) - expected = "test_source 1 0.123 0.456 word 0.789 lex speaker1\n" + expected = "test_source 1 0.12 0.46 word 0.789 lex speaker1\n" assert result == expected, "Failed on valid input" @pytest.mark.unit @@ -251,7 +251,7 @@ def test_default_values(self): type_of_token=None, speaker=None, ) - expected = "test_source 1 0.123 0.456 word NA unknown NA\n" + expected = "test_source 1 0.12 0.46 word NA unknown NA\n" assert result == expected, "Failed on default values" @@ -273,7 +273,6 @@ def test_normalize_audio(self, sample_len, gain): norm_array = normalize_audio(array_input) assert torch.max(torch.abs(norm_array)) == 1.0 assert torch.min(torch.abs(norm_array)) < 1.0 - @pytest.mark.parametrize("output_dir", [os.path.join(os.getcwd(), "test_dir")]) def test_get_cleaned_base_path(self, output_dir): result_path = get_cleaned_base_path(output_dir, overwrite_output=True) From 604be4d9ff6eb22022c5386fb80d3797211535c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Dec 2023 23:56:10 +0000 Subject: [PATCH 16/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/collections/asr/utils/test_data_simul_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 9264078329cc..110e6dd54243 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -273,6 +273,7 @@ def test_normalize_audio(self, sample_len, gain): norm_array = normalize_audio(array_input) assert torch.max(torch.abs(norm_array)) == 1.0 assert torch.min(torch.abs(norm_array)) < 1.0 + @pytest.mark.parametrize("output_dir", [os.path.join(os.getcwd(), "test_dir")]) def 
test_get_cleaned_base_path(self, output_dir): result_path = get_cleaned_base_path(output_dir, overwrite_output=True) From 28cc0d5471fe212973f4439146b021f92dc55feb Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Fri, 15 Dec 2023 14:58:19 -0800 Subject: [PATCH 17/22] Changed beg_time to start_time and fixed unit tests Signed-off-by: Taejin Park --- .../asr/parts/utils/data_simulation_utils.py | 2 +- .../asr/parts/utils/manifest_utils.py | 44 ++++++++-- .../create_alignment_manifest.py | 8 +- .../asr/utils/test_data_simul_utils.py | 85 ++++++++++++++----- .../utils/make_ctm_files.py | 2 +- 5 files changed, 105 insertions(+), 36 deletions(-) diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index 3f9ead1b99a0..66b21c2478a0 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -783,7 +783,7 @@ def create_new_ctm_entry( text = get_ctm_line( source=session_name, channel=1, - beg_time=align1, + start_time=align1, duration=align2, token=word, conf=None, diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 79965c17360e..a9e868bd0faa 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -33,10 +33,31 @@ from nemo.utils.data_utils import DataStoreObject +def get_rounded_str_float( + num: float, + output_precision: int, + min_precision=1, + max_precision=3 +)-> str: + """ + Get a string of a float number with rounded precision. + + Args: + num (float): float number to round + output_precision (int): precision of the output floating point number + min_precision (int, optional): Minimum precision of the output floating point number. Defaults to 1. + max_precision (int, optional): Maximum precision of the output floating point number. Defaults to 3. 
+ + Returns: + (str): Return a string of a float number with rounded precision. + """ + output_precision = min(max_precision, max(min_precision, output_precision)) + return f"{num:.{output_precision}f}" + def get_ctm_line( source: str, channel: int, - beg_time: float, + start_time: float, duration: float, token: str, conf: float, @@ -59,7 +80,7 @@ def get_ctm_line( Args: source (str): is name of the source file, session name or utterance ID channel (int): is channel number defaults to 1 - beg_time (float): is begin time of the word + start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) @@ -75,16 +96,21 @@ def get_ctm_line( """ VALID_TOKEN_TYPES = ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"] - if type(beg_time) == str and beg_time.replace('.', '', 1).isdigit(): - beg_time = float(beg_time) - elif type(beg_time) != float: - raise ValueError(f"`beg_time` must be a float or str containing float, but got {type(beg_time)}") + if type(start_time) == str and start_time.replace('.', '', 1).isdigit(): + start_time = float(start_time) + elif type(start_time) != float: + raise ValueError(f"`start_time` must be a float or str containing float, but got {type(start_time)}") if type(duration) == str and duration.replace('.', '', 1).isdigit(): duration = float(duration) elif type(duration) != float: raise ValueError(f"`duration` must be a float or str containing float, but got {type(duration)}") + if type(conf) == str and conf.replace('.', '', 1).isdigit(): + conf = float(conf) + elif type(duration) != float: + raise ValueError(f"`conf` must be a float or str containing float, but got {type(conf)}") + if channel is not None and type(channel) != int: channel = str(channel) if conf is not None and 
type(conf) != float: @@ -102,8 +128,10 @@ def get_ctm_line( conf = NA_token if conf is None else conf speaker = NA_token if speaker is None else speaker type_of_token = UNK if type_of_token is None else type_of_token - beg_time, duration = round(beg_time, output_precision), round(float(duration), output_precision) - return f"{source} {channel} {beg_time} {duration} {token} {conf} {type_of_token} {speaker}\n" + start_time = get_rounded_str_float(start_time, output_precision) + duration = get_rounded_str_float(duration, output_precision) + conf = get_rounded_str_float(conf, output_precision) if conf != NA_token else conf + return f"{source} {channel} {start_time} {duration} {token} {conf} {type_of_token} {speaker}\n" def rreplace(s: str, old: str, new: str) -> str: diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index a239229b05ed..91825ac0d7e9 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -26,7 +26,7 @@ def get_seg_info_from_ctm_line( ctm_list: List[str], output_precision: int, speaker_index: int = 7, - beg_time_index: int = 2, + start_time_index: int = 2, duration_index: int = 3, ): """ @@ -43,8 +43,8 @@ def get_seg_info_from_ctm_line( speaker_id (str): Speaker ID of the segment. 
""" speaker_id = ctm_list[speaker_index] - start = float(ctm_list[beg_time_index]) - end = float(ctm_list[beg_time_index]) + float(ctm_list[duration_index]) + start = float(ctm_list[start_time_index]) + end = float(ctm_list[start_time_index]) + float(ctm_list[duration_index]) start = round(start, output_precision) end = round(end, output_precision) return start, end, speaker_id @@ -103,7 +103,7 @@ def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_ text = get_ctm_line( source=session_name, channel=speaker_id, - beg_time=align1, + start_time=align1, duration=align2, token=word, conf=0, diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 9264078329cc..2fed51578fa4 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -139,7 +139,7 @@ def test_wrong_type_conf_values(self, conf): result = get_ctm_line( source="test_source", channel=1, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=conf, @@ -153,17 +153,19 @@ def test_wrong_type_conf_values(self, conf): @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.01, 0.99]) def test_valid_conf_values(self, conf): # Test with valid confidence values + output_precision = 2 result = get_ctm_line( source="test_source", channel=1, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=conf, type_of_token="lex", speaker="speaker1", + output_precision=output_precision, ) - expected = f"test_source 1 0.12 0.46 word {conf} lex speaker1\n" + expected = "test_source 1 0.12 0.46 word" + f" {conf:.{output_precision}f} lex speaker1\n" assert result == expected, f"Failed on valid conf value {conf}" @pytest.mark.unit @@ -174,7 +176,7 @@ def test_invalid_conf_ranges(self, conf): get_ctm_line( source="test_source", channel=1, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=conf, @@ -182,25 +184,53 @@ def 
test_invalid_conf_ranges(self, conf): speaker="speaker1", ) + @pytest.mark.unit + @pytest.mark.parametrize("start_time, duration, output_precision", [ + (0.123, 0.456, 2), + (1.0, 2.0, 1), + (0.0, 0.0, 2), + (0.01, 0.99, 3), + (1.23, 4.56, 2) + ]) + def test_valid_start_time_duration_with_precision(self, start_time, duration, output_precision): + # Test with valid beginning time, duration values and output precision + confidence = 0.5 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=confidence, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected_start_time = f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + expected_duration = f"{duration:.{output_precision}f}" # Adjusted to match the output format with precision + expected_confidence = f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + expected = f"test_source 1 {expected_start_time} {expected_duration} word {expected_confidence} lex speaker1\n" + assert result == expected, f"Failed on valid start_time {start_time}, duration {duration} with precision {output_precision}" + @pytest.mark.unit def test_valid_input(self): # Test with completely valid inputs result = get_ctm_line( source="test_source", channel=1, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=0.789, type_of_token="lex", speaker="speaker1", ) - expected = "test_source 1 0.12 0.46 word 0.789 lex speaker1\n" + expected = "test_source 1 0.12 0.46 word 0.79 lex speaker1\n" assert result == expected, "Failed on valid input" @pytest.mark.unit @pytest.mark.parametrize( - "beg_time, duration", + "start_time, duration", [ ("not a float", 1.0), (1.0, "not a float"), @@ -208,13 +238,13 @@ def test_valid_input(self): (2.0, 3), # Same as above ], ) - def test_invalid_types_for_time_duration(self, beg_time, duration): - # Test with 
invalid types for beg_time and duration + def test_invalid_types_for_time_duration(self, start_time, duration): + # Test with invalid types for start_time and duration with pytest.raises(ValueError): get_ctm_line( source="test_source", channel=1, - beg_time=beg_time, + start_time=start_time, duration=duration, token="word", conf=0.5, @@ -230,7 +260,7 @@ def test_invalid_conf_values(self, conf): get_ctm_line( source="test_source", channel=1, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=conf, @@ -244,7 +274,7 @@ def test_default_values(self): result = get_ctm_line( source="test_source", channel=None, - beg_time=0.123, + start_time=0.123, duration=0.456, token="word", conf=None, @@ -371,19 +401,30 @@ def test_create_new_json_entry(self, annotator): def test_create_new_ctm_entry(self, annotator): words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) - start = alignments[0] session_name = 'test_session' ctm_list = annotator.create_new_ctm_entry( - words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=start - ) - assert ctm_list[0] == ( - alignments[1], - f"{session_name} 1 {alignments[1]} {alignments[1]-alignments[0]} {words[1]} NA lex {speaker_id}\n", - ) - assert ctm_list[1] == ( - alignments[2], - f"{session_name} 1 {alignments[2]} {alignments[2]-alignments[1]} {words[2]} NA lex {speaker_id}\n", + words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] ) + assert ctm_list[0] == (alignments[1], get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[1], + duration=float(alignments[1] - alignments[0]), + token=words[1], + conf=None, + type_of_token='lex', + speaker=speaker_id, + )) + assert ctm_list[1] == (alignments[2], get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[2], + duration=float(alignments[2] - alignments[1]), + token=words[2], + conf=None, + type_of_token='lex', + 
speaker=speaker_id, + )) class TestSpeechSampler: diff --git a/tools/nemo_forced_aligner/utils/make_ctm_files.py b/tools/nemo_forced_aligner/utils/make_ctm_files.py index fbc5b7f0e15c..17b734f20595 100644 --- a/tools/nemo_forced_aligner/utils/make_ctm_files.py +++ b/tools/nemo_forced_aligner/utils/make_ctm_files.py @@ -109,7 +109,7 @@ def make_ctm( ctm_line = get_ctm_line( source=utt_obj.utt_id, channel=1, - beg_time=start_time, + start_time=start_time, duration=end_time - start_time, token=text, conf=None, From 4a97f442b58ca4be1566d94ee4f5595f3f8bf9db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Dec 2023 22:59:34 +0000 Subject: [PATCH 18/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../asr/parts/utils/manifest_utils.py | 8 +-- .../asr/utils/test_data_simul_utils.py | 69 +++++++++++-------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index a9e868bd0faa..dbd385788793 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -33,12 +33,7 @@ from nemo.utils.data_utils import DataStoreObject -def get_rounded_str_float( - num: float, - output_precision: int, - min_precision=1, - max_precision=3 -)-> str: +def get_rounded_str_float(num: float, output_precision: int, min_precision=1, max_precision=3) -> str: """ Get a string of a float number with rounded precision. 
@@ -54,6 +49,7 @@ def get_rounded_str_float( output_precision = min(max_precision, max(min_precision, output_precision)) return f"{num:.{output_precision}f}" + def get_ctm_line( source: str, channel: int, diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 5f9d05d7baae..295b79c76d18 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -185,13 +185,10 @@ def test_invalid_conf_ranges(self, conf): ) @pytest.mark.unit - @pytest.mark.parametrize("start_time, duration, output_precision", [ - (0.123, 0.456, 2), - (1.0, 2.0, 1), - (0.0, 0.0, 2), - (0.01, 0.99, 3), - (1.23, 4.56, 2) - ]) + @pytest.mark.parametrize( + "start_time, duration, output_precision", + [(0.123, 0.456, 2), (1.0, 2.0, 1), (0.0, 0.0, 2), (0.01, 0.99, 3), (1.23, 4.56, 2)], + ) def test_valid_start_time_duration_with_precision(self, start_time, duration, output_precision): # Test with valid beginning time, duration values and output precision confidence = 0.5 @@ -206,11 +203,17 @@ def test_valid_start_time_duration_with_precision(self, start_time, duration, ou speaker="speaker1", output_precision=output_precision, ) - expected_start_time = f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + expected_start_time = ( + f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + ) expected_duration = f"{duration:.{output_precision}f}" # Adjusted to match the output format with precision - expected_confidence = f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + expected_confidence = ( + f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + ) expected = f"test_source 1 {expected_start_time} {expected_duration} word {expected_confidence} lex speaker1\n" - assert result == expected, f"Failed on valid start_time 
{start_time}, duration {duration} with precision {output_precision}" + assert ( + result == expected + ), f"Failed on valid start_time {start_time}, duration {duration} with precision {output_precision}" @pytest.mark.unit def test_valid_input(self): @@ -406,26 +409,32 @@ def test_create_new_ctm_entry(self, annotator): ctm_list = annotator.create_new_ctm_entry( words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] ) - assert ctm_list[0] == (alignments[1], get_ctm_line( - source=session_name, - channel="1", - start_time=alignments[1], - duration=float(alignments[1] - alignments[0]), - token=words[1], - conf=None, - type_of_token='lex', - speaker=speaker_id, - )) - assert ctm_list[1] == (alignments[2], get_ctm_line( - source=session_name, - channel="1", - start_time=alignments[2], - duration=float(alignments[2] - alignments[1]), - token=words[2], - conf=None, - type_of_token='lex', - speaker=speaker_id, - )) + assert ctm_list[0] == ( + alignments[1], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[1], + duration=float(alignments[1] - alignments[0]), + token=words[1], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) + assert ctm_list[1] == ( + alignments[2], + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[2], + duration=float(alignments[2] - alignments[1]), + token=words[2], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), + ) class TestSpeechSampler: From 1166231ad8659557baf520ee6a46cde8fd1c2b8e Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 18 Dec 2023 15:40:35 -0800 Subject: [PATCH 19/22] Fixed typos and errors in manifest_utils.py Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/manifest_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 
a9e868bd0faa..918162320519 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -36,8 +36,8 @@ def get_rounded_str_float( num: float, output_precision: int, - min_precision=1, - max_precision=3 + min_precision: int=1, + max_precision: int=3, )-> str: """ Get a string of a float number with rounded precision. @@ -69,7 +69,7 @@ def get_ctm_line( output_precision: int = 2, ) -> str: """ - Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. CTM Format: @@ -108,7 +108,7 @@ def get_ctm_line( if type(conf) == str and conf.replace('.', '', 1).isdigit(): conf = float(conf) - elif type(duration) != float: + elif type(conf) != float: raise ValueError(f"`conf` must be a float or str containing float, but got {type(conf)}") if channel is not None and type(channel) != int: From 4c2421fa0152ebe1771a7301545c78cead7b99ae Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 18 Dec 2023 15:43:06 -0800 Subject: [PATCH 20/22] Resolved another merge conflict Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/manifest_utils.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 2716631a70f3..8eae73a91143 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -33,16 +33,7 @@ from nemo.utils.data_utils import DataStoreObject -<<<<<<< HEAD -def get_rounded_str_float( - num: float, - output_precision: int, - min_precision: int=1, - max_precision: int=3, -)-> str: -======= def get_rounded_str_float(num: float, output_precision: int, min_precision=1, max_precision=3) -> str: ->>>>>>> 
2eaa24b17589a05482f2d4f8a511ad87d8c3c2c9 """ Get a string of a float number with rounded precision. From a56d047c6ace436fe3b36f2b5c06a93e6f7ce1fe Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Thu, 4 Jan 2024 16:51:21 -0800 Subject: [PATCH 21/22] Fixed the test errors Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/manifest_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 8eae73a91143..930a085cf1f2 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -104,19 +104,21 @@ def get_ctm_line( if type(conf) == str and conf.replace('.', '', 1).isdigit(): conf = float(conf) + elif conf is None: + conf = NA_token elif type(conf) != float: raise ValueError(f"`conf` must be a float or str containing float, but got {type(conf)}") if channel is not None and type(channel) != int: channel = str(channel) - if conf is not None and type(conf) != float: - raise ValueError(f"`conf` must be a float, but got {type(conf)}") - if conf is not None and not (0 <= conf <= 1): + # if conf is not None and type(conf) != float: + # raise ValueError(f"`conf` must be a float, but got {type(conf)} type {conf}") + if conf is not None and type(conf) == float and not (0 <= conf <= 1): raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") if type_of_token is not None and type(type_of_token) != str: - raise ValueError(f"`type` must be a string, but got {type(type)}") + raise ValueError(f"`type` must be a string, but got {type(type_of_token)} type {type_of_token}") if type_of_token is not None and type_of_token not in VALID_TOKEN_TYPES: - raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token}") + raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token} type {type_of_token}") if speaker is not None 
and type(speaker) != str: raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") From 851f7161f47c824c21cf5c1691daeae2992e62f7 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Fri, 5 Jan 2024 12:07:20 -0800 Subject: [PATCH 22/22] Fixed the missed commented lines Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/manifest_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 930a085cf1f2..71a35ceb3426 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -111,8 +111,6 @@ def get_ctm_line( if channel is not None and type(channel) != int: channel = str(channel) - # if conf is not None and type(conf) != float: - # raise ValueError(f"`conf` must be a float, but got {type(conf)} type {conf}") if conf is not None and type(conf) == float and not (0 <= conf <= 1): raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") if type_of_token is not None and type(type_of_token) != str: