diff --git a/docs/source/asr/speaker_diarization/datasets.rst b/docs/source/asr/speaker_diarization/datasets.rst index 9f1a43a58f11..952f426b4ff8 100644 --- a/docs/source/asr/speaker_diarization/datasets.rst +++ b/docs/source/asr/speaker_diarization/datasets.rst @@ -205,14 +205,14 @@ The following are descriptions about each field in an input manifest JSON file. ``ctm_filepath`` (Optional): - CTM file is used for the evaluation of word-level diarization results and word-timestamp alignment. CTM file follows the following convention: `` `` Since confidence is not required for evaluating diarization results, it can have any value. Note that the ```` should be exactly matched with speaker IDs in RTTM. + The CTM file is used for the evaluation of word-level diarization results and word-timestamp alignment. The CTM file follows this convention: `` ``. Note that the ```` should exactly match speaker IDs in RTTM. Since confidence is not required for evaluating diarization results, we assign ```` the value ``NA``. If the type of token is words, we assign ```` as ``lex``. Example lines of CTM file: .. code-block:: bash - TS3012d.Mix-Headset MTD046ID 12.879 0.32 okay 0 - TS3012d.Mix-Headset MTD046ID 13.203 0.24 yeah 0 + TS3012d.Mix-Headset 1 12.879 0.32 okay NA lex MTD046ID + TS3012d.Mix-Headset 1 13.203 0.24 yeah NA lex MTD046ID Evaluation on Benchmark Datasets diff --git a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py index da9671a2fc43..d15adb5b826a 100644 --- a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py +++ b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_with_asr_infer.py @@ -63,8 +63,6 @@ def main(cfg): # If RTTM is provided and DER evaluation if diar_score is not None: - metric, mapping_dict, _ = diar_score - # Get session-level diarization error rate and speaker counting error der_results = OfflineDiarWithASR.gather_eval_results( diar_score=diar_score, diff --git a/nemo/collections/asr/parts/utils/data_simulation_utils.py b/nemo/collections/asr/parts/utils/data_simulation_utils.py index a9a1e10ae385..66b21c2478a0 100644 --- a/nemo/collections/asr/parts/utils/data_simulation_utils.py +++ b/nemo/collections/asr/parts/utils/data_simulation_utils.py @@ -25,7 +25,13 @@ from nemo.collections.asr.parts.preprocessing.perturb import AudioAugmentor from nemo.collections.asr.parts.preprocessing.segment import AudioSegment -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest, write_text +from nemo.collections.asr.parts.utils.manifest_utils import ( + get_ctm_line, + read_manifest, + write_ctm, + write_manifest, + write_text, +) from nemo.collections.asr.parts.utils.speaker_utils import labels_to_rttmfile from nemo.utils import logging @@ -774,7 +780,16 @@ def create_new_ctm_entry( prev_align = 0 if i == 0 else alignments[i - 1] align1 = round(float(prev_align + start), self._params.data_simulator.outputs.output_precision) align2 = round(float(alignments[i] - prev_align), self._params.data_simulator.outputs.output_precision) - text = f"{session_name} {speaker_id} {align1} {align2} {word} 0\n" + text = get_ctm_line( + source=session_name, + channel=1, + start_time=align1, + duration=align2, + token=word, + conf=None, + type_of_token='lex', + speaker=speaker_id, + ) arr.append((align1, text)) return arr diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index 39bd6a8a24e7..71a35ceb3426 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -33,6 +33,103 @@ from nemo.utils.data_utils import DataStoreObject +def get_rounded_str_float(num: float, output_precision: int, min_precision=1, max_precision=3) -> str: + """ + Get a string of a float number with rounded precision. + + Args: + num (float): float number to round + output_precision (int): precision of the output floating point number + min_precision (int, optional): Minimum precision of the output floating point number. Defaults to 1. + max_precision (int, optional): Maximum precision of the output floating point number. Defaults to 3. + + Returns: + (str): Return a string of a float number with rounded precision. + """ + output_precision = min(max_precision, max(min_precision, output_precision)) + return f"{num:.{output_precision}f}" + + +def get_ctm_line( + source: str, + channel: int, + start_time: float, + duration: float, + token: str, + conf: float, + type_of_token: str, + speaker: str, + NA_token: str = 'NA', + UNK: str = 'unknown', + default_channel: str = '1', + output_precision: int = 2, +) -> str: + """ + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + + CTM Format: + + + Reference: + https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf + + Args: + source (str): is name of the source file, session name or utterance ID + channel (int): is channel number defaults to 1 + start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. + duration (float): is duration of the word + token (str): Token or word for the current entry + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) + when no confidence is computed and in the reference data. + type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when + the speaker has not been determined. + NA_token (str, optional): A token for . Defaults to ''. + output_precision (int, optional): The precision of the output floating point number. Defaults to 3. + + Returns: + str: Return a line in CTM format filled with the given information. + """ + VALID_TOKEN_TYPES = ["lex", "frag", "fp", "un-lex", "for-lex", "non-lex", "misc", "noscore"] + + if type(start_time) == str and start_time.replace('.', '', 1).isdigit(): + start_time = float(start_time) + elif type(start_time) != float: + raise ValueError(f"`start_time` must be a float or str containing float, but got {type(start_time)}") + + if type(duration) == str and duration.replace('.', '', 1).isdigit(): + duration = float(duration) + elif type(duration) != float: + raise ValueError(f"`duration` must be a float or str containing float, but got {type(duration)}") + + if type(conf) == str and conf.replace('.', '', 1).isdigit(): + conf = float(conf) + elif conf is None: + conf = NA_token + elif type(conf) != float: + raise ValueError(f"`conf` must be a float or str containing float, but got {type(conf)}") + + if channel is not None and type(channel) != int: + channel = str(channel) + if conf is not None and type(conf) == float and not (0 <= conf <= 1): + raise ValueError(f"`conf` must be between 0 and 1, but got {conf}") + if type_of_token is not None and type(type_of_token) != str: + raise ValueError(f"`type` must be a string, but got {type(type_of_token)} type {type_of_token}") + if type_of_token is not None and type_of_token not in VALID_TOKEN_TYPES: + raise ValueError(f"`type` must be one of {VALID_TOKEN_TYPES}, but got {type_of_token} type {type_of_token}") + if speaker is not None and type(speaker) != str: + raise ValueError(f"`speaker` must be a string, but got {type(speaker)}") + + channel = default_channel if channel is None else channel + conf = NA_token if conf is None else conf + speaker = NA_token if speaker is None else speaker + type_of_token = UNK if type_of_token is None else type_of_token + start_time = get_rounded_str_float(start_time, output_precision) + duration = get_rounded_str_float(duration, output_precision) + conf = get_rounded_str_float(conf, output_precision) if conf != NA_token else conf + return f"{source} {channel} {start_time} {duration} {token} {conf} {type_of_token} {speaker}\n" + + def rreplace(s: str, old: str, new: str) -> str: """ Replace end of string. diff --git a/scripts/speaker_tasks/create_alignment_manifest.py b/scripts/speaker_tasks/create_alignment_manifest.py index e2b15b03b842..91825ac0d7e9 100644 --- a/scripts/speaker_tasks/create_alignment_manifest.py +++ b/scripts/speaker_tasks/create_alignment_manifest.py @@ -16,12 +16,41 @@ import os import shutil from pathlib import Path +from typing import List -from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_ctm, write_manifest +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line, read_manifest, write_ctm, write_manifest from nemo.utils import logging -def get_unaligned_files(unaligned_path): +def get_seg_info_from_ctm_line( + ctm_list: List[str], + output_precision: int, + speaker_index: int = 7, + start_time_index: int = 2, + duration_index: int = 3, +): + """ + Get time stamp information and speaker labels from CTM lines. + This is following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. + + Args: + ctm_list (list): List containing CTM items. e.g.: ['sw02001-A', '1', '0.000', '0.200', 'hello', '0.98', 'lex', 'speaker3'] + output_precision (int): Precision for CTM outputs in integer. + + Returns: + start (float): Start time of the segment. + end (float): End time of the segment. + speaker_id (str): Speaker ID of the segment. + """ + speaker_id = ctm_list[speaker_index] + start = float(ctm_list[start_time_index]) + end = float(ctm_list[start_time_index]) + float(ctm_list[duration_index]) + start = round(start, output_precision) + end = round(end, output_precision) + return start, end, speaker_id + + +def get_unaligned_files(unaligned_path: str) -> List[str]: """ Get files without alignments in order to filter them out (as they cannot be used for data simulation). In the unaligned file, each line contains the file name and the reason for the unalignment, if necessary to specify. @@ -71,7 +100,17 @@ def create_new_ctm_entry(session_name, speaker_id, wordlist, alignments, output_ # note that using the current alignments the first word is always empty, so there is no error from indexing the array with i-1 align1 = float(round(alignments[i - 1], output_precision)) align2 = float(round(alignments[i] - alignments[i - 1], output_precision,)) - text = f"{session_name} {speaker_id} {align1} {align2} {word} 0\n" + text = get_ctm_line( + source=session_name, + channel=speaker_id, + start_time=align1, + duration=align2, + token=word, + conf=0, + type_of_token='lex', + speaker=speaker_id, + output_precision=output_precision, + ) arr.append((align1, text)) return arr @@ -206,11 +245,7 @@ def create_manifest_with_alignments( prev_end = 0 for i in range(len(lines)): ctm = lines[i].split(' ') - speaker_id = ctm[1] - start = float(ctm[2]) - end = float(ctm[2]) + float(ctm[3]) - start = round(start, output_precision) - end = round(end, output_precision) + speaker_id, start, end = get_seg_info_from_ctm_line(ctm_list=ctm, output_precision=output_precision) interval = start - prev_end if (i == 0 and interval > 0) or (i > 0 and interval > silence_dur_threshold): @@ -231,13 +266,16 @@ def create_manifest_with_alignments( end_times.append(f['duration']) # build target manifest entry - target_manifest.append({}) - target_manifest[tgt_i]['audio_filepath'] = f['audio_filepath'] - target_manifest[tgt_i]['duration'] = f['duration'] - target_manifest[tgt_i]['text'] = f['text'] - target_manifest[tgt_i]['words'] = words - target_manifest[tgt_i]['alignments'] = end_times - target_manifest[tgt_i]['speaker_id'] = speaker_id + target_manifest.append( + { + 'audio_filepath': f['audio_filepath'], + 'duration': f['duration'], + 'text': f['text'], + 'words': words, + 'alignments': end_times, + 'speaker_id': speaker_id, + } + ) src_i += 1 tgt_i += 1 diff --git a/tests/collections/asr/utils/test_data_simul_utils.py b/tests/collections/asr/utils/test_data_simul_utils.py index 4592043248ae..295b79c76d18 100644 --- a/tests/collections/asr/utils/test_data_simul_utils.py +++ b/tests/collections/asr/utils/test_data_simul_utils.py @@ -29,6 +29,7 @@ normalize_audio, read_noise_manifest, ) +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line @pytest.fixture() @@ -129,6 +130,164 @@ def generate_words_and_alignments(sample_index): return words, alignments, speaker_id +class TestGetCtmLine: + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0, 1]) + def test_wrong_type_conf_values(self, conf): + # Test with wrong integer confidence values + with pytest.raises(ValueError): + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + expected = f"test_source 1 0.12 0.46 word {conf} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [0.0, 0.5, 1.0, 0.01, 0.99]) + def test_valid_conf_values(self, conf): + # Test with valid confidence values + output_precision = 2 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected = "test_source 1 0.12 0.46 word" + f" {conf:.{output_precision}f} lex speaker1\n" + assert result == expected, f"Failed on valid conf value {conf}" + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, 2, -1, 100, -100]) + def test_invalid_conf_ranges(self, conf): + # Test with invalid confidence values + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration, output_precision", + [(0.123, 0.456, 2), (1.0, 2.0, 1), (0.0, 0.0, 2), (0.01, 0.99, 3), (1.23, 4.56, 2)], + ) + def test_valid_start_time_duration_with_precision(self, start_time, duration, output_precision): + # Test with valid beginning time, duration values and output precision + confidence = 0.5 + result = get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=confidence, + type_of_token="lex", + speaker="speaker1", + output_precision=output_precision, + ) + expected_start_time = ( + f"{start_time:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected_duration = f"{duration:.{output_precision}f}" # Adjusted to match the output format with precision + expected_confidence = ( + f"{confidence:.{output_precision}f}" # Adjusted to match the output format with precision + ) + expected = f"test_source 1 {expected_start_time} {expected_duration} word {expected_confidence} lex speaker1\n" + assert ( + result == expected + ), f"Failed on valid start_time {start_time}, duration {duration} with precision {output_precision}" + + @pytest.mark.unit + def test_valid_input(self): + # Test with completely valid inputs + result = get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=0.789, + type_of_token="lex", + speaker="speaker1", + ) + expected = "test_source 1 0.12 0.46 word 0.79 lex speaker1\n" + assert result == expected, "Failed on valid input" + + @pytest.mark.unit + @pytest.mark.parametrize( + "start_time, duration", + [ + ("not a float", 1.0), + (1.0, "not a float"), + (1, 2.0), # Integers should be converted to float + (2.0, 3), # Same as above + ], + ) + def test_invalid_types_for_time_duration(self, start_time, duration): + # Test with invalid types for start_time and duration + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=start_time, + duration=duration, + token="word", + conf=0.5, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + @pytest.mark.parametrize("conf", [-0.1, 1.1, "not a float"]) + def test_invalid_conf_values(self, conf): + # Test with invalid values for conf + with pytest.raises(ValueError): + get_ctm_line( + source="test_source", + channel=1, + start_time=0.123, + duration=0.456, + token="word", + conf=conf, + type_of_token="lex", + speaker="speaker1", + ) + + @pytest.mark.unit + def test_default_values(self): + # Test with missing optional parameters + result = get_ctm_line( + source="test_source", + channel=None, + start_time=0.123, + duration=0.456, + token="word", + conf=None, + type_of_token=None, + speaker=None, + ) + expected = "test_source 1 0.12 0.46 word NA unknown NA\n" + assert result == expected, "Failed on default values" + + class TestDataSimulatorUtils: # TODO: add tests for all util functions @pytest.mark.parametrize("max_audio_read_sec", [2.5, 3.5, 4.5]) @@ -246,18 +405,35 @@ def test_create_new_json_entry(self, annotator): def test_create_new_ctm_entry(self, annotator): words, alignments, speaker_id = generate_words_and_alignments(sample_index=0) - start = alignments[0] session_name = 'test_session' ctm_list = annotator.create_new_ctm_entry( - words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=start + words=words, alignments=alignments, session_name=session_name, speaker_id=speaker_id, start=alignments[0] ) assert ctm_list[0] == ( alignments[1], - f"{session_name} {speaker_id} {alignments[1]} {alignments[1]-alignments[0]} {words[1]} 0\n", + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[1], + duration=float(alignments[1] - alignments[0]), + token=words[1], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), ) assert ctm_list[1] == ( alignments[2], - f"{session_name} {speaker_id} {alignments[2]} {alignments[2]-alignments[1]} {words[2]} 0\n", + get_ctm_line( + source=session_name, + channel="1", + start_time=alignments[2], + duration=float(alignments[2] - alignments[1]), + token=words[2], + conf=None, + type_of_token='lex', + speaker=speaker_id, + ), ) diff --git a/tools/nemo_forced_aligner/utils/make_ctm_files.py b/tools/nemo_forced_aligner/utils/make_ctm_files.py index f0326c07cf8f..17b734f20595 100644 --- a/tools/nemo_forced_aligner/utils/make_ctm_files.py +++ b/tools/nemo_forced_aligner/utils/make_ctm_files.py @@ -17,6 +17,7 @@ import soundfile as sf from utils.constants import BLANK_TOKEN, SPACE_TOKEN from utils.data_prep import Segment, Word +from nemo.collections.asr.parts.utils.manifest_utils import get_ctm_line def make_ctm_files( @@ -105,7 +106,17 @@ def make_ctm( # replace any spaces with so we dont introduce extra space characters to our CTM files text = text.replace(" ", SPACE_TOKEN) - f_ctm.write(f"{utt_obj.utt_id} 1 {start_time:.2f} {end_time - start_time:.2f} {text}\n") + ctm_line = get_ctm_line( + source=utt_obj.utt_id, + channel=1, + start_time=start_time, + duration=end_time - start_time, + token=text, + conf=None, + type_of_token='lex', + speaker=None, + ) + f_ctm.write(ctm_line) utt_obj.saved_output_files[f"{alignment_level}_level_ctm_filepath"] = os.path.join( output_dir, f"{utt_obj.utt_id}.ctm" diff --git a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb index ea943b35e0d0..0fb2b62610a6 100644 --- a/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb +++ b/tutorials/speaker_tasks/ASR_with_SpeakerDiarization.ipynb @@ -537,11 +537,11 @@ "source": [ "We also need CTM files as reference transcripts. The columns of a CTM file indicate the following items:\n", "\n", - "` `\n", + "` `\n", "- Example:\n", - "`diar_session_123 speaker_3 13.2 0.25 hi 0`\n", + "`diar_session_123 1 13.2 0.25 hi 0 lex speaker_3`\n", "\n", - "For the purpose of creating the reference annotations, we use `` for speaker labels and assign 0 to ``. The reference CTM file for the `an4_diarize_test.wav` looks like the following:" + "For the purpose of creating the reference annotations, we set `1` for `` and assign \"`NA`\" to ``, \"`lex`\" to ``. The reference CTM file for the `an4_diarize_test.wav` looks like the following:" ] }, { @@ -551,16 +551,16 @@ "outputs": [], "source": [ "an4_diarize_test_ctm = \\\n", - "[\"an4_diarize_test speaker_0 0.4 0.51 eleven 0\",\n", - "\"an4_diarize_test speaker_0 0.95 0.32 twenty 0\",\n", - "\"an4_diarize_test speaker_0 1.35 0.55 seven 0\",\n", - "\"an4_diarize_test speaker_0 1.96 0.32 fifty 0\",\n", - "\"an4_diarize_test speaker_0 2.32 0.75 seven 0\",\n", - "\"an4_diarize_test speaker_1 3.12 0.42 october 0\",\n", - "\"an4_diarize_test speaker_1 3.6 0.28 twenty 0\",\n", - "\"an4_diarize_test speaker_1 3.95 0.35 four 0\",\n", - "\"an4_diarize_test speaker_1 4.3 0.31 nineteen 0\",\n", - "\"an4_diarize_test speaker_1 4.65 0.35 seventy 0\"]" + "[\"an4_diarize_test 1 0.4 0.51 eleven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 0.95 0.32 twenty NA lex speaker_0\",\n", + "\"an4_diarize_test 1 1.35 0.55 seven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 1.96 0.32 fifty NA lex speaker_0\",\n", + "\"an4_diarize_test 1 2.32 0.75 seven NA lex speaker_0\",\n", + "\"an4_diarize_test 1 3.12 0.42 october NA lex speaker_1\",\n", + "\"an4_diarize_test 1 3.6 0.28 twenty NA lex speaker_1\",\n", + "\"an4_diarize_test 1 3.95 0.35 four NA lex speaker_1\",\n", + "\"an4_diarize_test 1 4.3 0.31 nineteen NA lex speaker_1\",\n", + "\"an4_diarize_test 1 4.65 0.35 seventy NA lex speaker_1\",]" ] }, { @@ -959,7 +959,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.13" } }, "nbformat": 4,