From a900e62f3ed4ea72fd86ea42b8895261476df53b Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 25 Oct 2023 07:04:31 -0700 Subject: [PATCH] add confidence Signed-off-by: andrusenkoau --- nemo/collections/asr/metrics/rnnt_wer.py | 12 +- nemo/collections/asr/metrics/wer.py | 2 +- .../parts/submodules/ctc_greedy_decoding.py | 35 +++--- .../parts/submodules/rnnt_beam_decoding.py | 1 + .../parts/submodules/rnnt_greedy_decoding.py | 119 +++++++++++------- .../asr/parts/utils/asr_confidence_utils.py | 109 +++++++--------- ...arch_ngram_transducer_wb-ctc_confidence.py | 58 +++++---- 7 files changed, 178 insertions(+), 158 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer.py b/nemo/collections/asr/metrics/rnnt_wer.py index 92bb488a6039..511c0969821d 100644 --- a/nemo/collections/asr/metrics/rnnt_wer.py +++ b/nemo/collections/asr/metrics/rnnt_wer.py @@ -277,7 +277,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) else: self.decoding = greedy_decode.GreedyTDTInfer( @@ -291,7 +291,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) else: self.decoding = greedy_decode.GreedyMultiblankRNNTInfer( @@ -304,7 +304,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) elif self.cfg.strategy == 'greedy_batch': @@ -320,7 +320,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) else: self.decoding = greedy_decode.GreedyBatchedTDTInfer( @@ -334,7 +334,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) else: @@ -348,7 +348,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None): ), preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) elif self.cfg.strategy == 'beam': diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index a88895763edc..a5ad414dd0d4 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -389,7 +389,7 @@ def __init__(self, decoding_cfg, blank_id: int): preserve_alignments=self.preserve_alignments, compute_timestamps=self.compute_timestamps, preserve_frame_confidence=self.preserve_frame_confidence, - confidence_measure_cfg=self.confidence_measure_cfg, + confidence_method_cfg=self.confidence_method_cfg, ) elif self.cfg.strategy == 'beam': diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index 1f29a511fc9c..9b195993a0a7 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -19,7 +19,8 @@ from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.parts.utils import rnnt_utils -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin +# from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType from nemo.utils import logging @@ -55,7 +56,7 @@ def _states_to_device(dec_state, device='cpu'): return dec_state -class GreedyCTCInfer(Typing, ConfidenceMeasureMixin): +class GreedyCTCInfer(Typing, ConfidenceMethodMixin): """A greedy CTC decoder. Provides a common abstraction for sample level and batch level greedy decoding. @@ -71,7 +72,7 @@ class GreedyCTCInfer(Typing, ConfidenceMeasureMixin): preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding. When set to true, the Hypothesis will contain the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats. - confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame + confidence_method_cfg: A dict-like object which contains the measure name and settings to compute per-frame confidence scores. name: The measure name (str). @@ -79,7 +80,7 @@ class GreedyCTCInfer(Typing, ConfidenceMeasureMixin): - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. Supported values: - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). @@ -130,7 +131,7 @@ def __init__( preserve_alignments: bool = False, compute_timestamps: bool = False, preserve_frame_confidence: bool = False, - confidence_measure_cfg: Optional[DictConfig] = None, + confidence_method_cfg: Optional[DictConfig] = None, ): super().__init__() @@ -141,7 +142,7 @@ def __init__( self.preserve_frame_confidence = preserve_frame_confidence # set confidence calculation measure - self._init_confidence_measure(confidence_measure_cfg) + self._init_confidence_method(confidence_method_cfg) @typecheck() def forward( @@ -253,27 +254,27 @@ class GreedyCTCInferConfig: preserve_alignments: bool = False compute_timestamps: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() - confidence_method_cfg: str = "DEPRECATED" + confidence_method_cfg: Optional[ConfidenceMethodConfig] = ConfidenceMethodConfig() + # confidence_method_cfg: str = "DEPRECATED" def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( - self.confidence_measure_cfg - if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + self.confidence_method_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) + else ConfidenceMethodConfig(**self.confidence_method_cfg) ) if self.confidence_method_cfg != "DEPRECATED": logging.warning( "`confidence_method_cfg` is deprecated and will be removed in the future. " - "Please use `confidence_measure_cfg` instead." + "Please use `confidence_method_cfg` instead." ) # TODO (alaptev): delete the following two lines sometime in the future - logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") + logging.warning("Re-writing `confidence_method_cfg` with the value of `confidence_method_cfg`.") # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( + self.confidence_method_cfg = OmegaConf.structured( self.confidence_method_cfg - if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_method_cfg) + if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) + else ConfidenceMethodConfig(**self.confidence_method_cfg) ) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 8783a76d8eef..cb512bc5badc 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -1566,6 +1566,7 @@ class BeamRNNTInferConfig: language_model: Optional[Dict[str, Any]] = None softmax_temperature: float = 1.0 preserve_alignments: bool = False + preserve_frame_confidence: bool = True ngram_lm_model: Optional[str] = None ngram_lm_alpha: Optional[float] = 0.0 hat_subtract_ilm: bool = False diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 9cd1c3632ce6..24fc574ea1f3 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -26,7 +26,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional, Tuple, Union import numpy as np @@ -35,7 +35,8 @@ from nemo.collections.asr.modules import rnnt_abstract from nemo.collections.asr.parts.utils import rnnt_utils -from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin +# from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.collections.common.parts.rnn import label_collate from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType @@ -69,7 +70,7 @@ def _states_to_device(dec_state, device='cpu'): return dec_state -class _GreedyRNNTInfer(Typing, ConfidenceMeasureMixin): +class _GreedyRNNTInfer(Typing, ConfidenceMethodMixin): """A greedy transducer decoder. Provides a common abstraction for sample level and batch level greedy decoding. @@ -154,7 +155,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_measure_cfg: Optional[DictConfig] = None, + confidence_method_cfg: Optional[DictConfig] = None, ): super().__init__() self.decoder = decoder_model @@ -167,7 +168,7 @@ def __init__( self.preserve_frame_confidence = preserve_frame_confidence # set confidence calculation measure - self._init_confidence_measure(confidence_measure_cfg) + self._init_confidence_method(confidence_method_cfg) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -544,7 +545,7 @@ def __init__( max_symbols_per_step: Optional[int] = None, preserve_alignments: bool = False, preserve_frame_confidence: bool = False, - confidence_measure_cfg: Optional[DictConfig] = None, + confidence_method_cfg: Optional[DictConfig] = None, ): super().__init__( decoder_model=decoder_model, @@ -553,7 +554,7 @@ def __init__( max_symbols_per_step=max_symbols_per_step, preserve_alignments=preserve_alignments, preserve_frame_confidence=preserve_frame_confidence, - confidence_measure_cfg=confidence_measure_cfg, + confidence_method_cfg=confidence_method_cfg, ) # Depending on availability of `blank_as_pad` support @@ -2204,36 +2205,51 @@ def _greedy_decode_masked( return hypotheses +# @dataclass +# class GreedyRNNTInferConfig: +# max_symbols_per_step: Optional[int] = 10 +# preserve_alignments: bool = False +# preserve_frame_confidence: bool = False +# confidence_measure_cfg: Optional[ConfidenceMethodConfig] = ConfidenceMethodConfig() +# confidence_method_cfg: str = "DEPRECATED" + +# def __post_init__(self): +# # OmegaConf.structured ensures that post_init check is always executed +# self.confidence_measure_cfg = OmegaConf.structured( +# self.confidence_measure_cfg +# if isinstance(self.confidence_measure_cfg, ConfidenceMethodConfig) +# else ConfidenceMethodConfig(**self.confidence_measure_cfg) +# ) +# if self.confidence_method_cfg != "DEPRECATED": +# logging.warning( +# "`confidence_method_cfg` is deprecated and will be removed in the future. " +# "Please use `confidence_measure_cfg` instead." +# ) + +# # TODO (alaptev): delete the following two lines sometime in the future +# logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") +# # OmegaConf.structured ensures that post_init check is always executed +# self.confidence_measure_cfg = OmegaConf.structured( +# self.confidence_method_cfg +# if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) +# else ConfidenceMethodConfig(**self.confidence_method_cfg) +# ) +# self.confidence_method_cfg = "DEPRECATED" + @dataclass class GreedyRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() - confidence_method_cfg: str = "DEPRECATED" + confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( - self.confidence_measure_cfg - if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + self.confidence_method_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) + else ConfidenceMethodConfig(**self.confidence_method_cfg) ) - if self.confidence_method_cfg != "DEPRECATED": - logging.warning( - "`confidence_method_cfg` is deprecated and will be removed in the future. " - "Please use `confidence_measure_cfg` instead." - ) - - # TODO (alaptev): delete the following two lines sometime in the future - logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") - # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( - self.confidence_method_cfg - if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_method_cfg) - ) - self.confidence_method_cfg = "DEPRECATED" @dataclass @@ -2241,31 +2257,38 @@ class GreedyBatchedRNNTInferConfig: max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False - confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig() - confidence_method_cfg: str = "DEPRECATED" + confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( - self.confidence_measure_cfg - if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_measure_cfg) + self.confidence_method_cfg = OmegaConf.structured( + self.confidence_method_cfg + if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) + else ConfidenceMethodConfig(**self.confidence_method_cfg) ) - if self.confidence_method_cfg != "DEPRECATED": - logging.warning( - "`confidence_method_cfg` is deprecated and will be removed in the future. " - "Please use `confidence_measure_cfg` instead." - ) - # TODO (alaptev): delete the following two lines sometime in the future - logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") - # OmegaConf.structured ensures that post_init check is always executed - self.confidence_measure_cfg = OmegaConf.structured( - self.confidence_method_cfg - if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.confidence_method_cfg) - ) - self.confidence_method_cfg = "DEPRECATED" + # def __post_init__(self): + # # OmegaConf.structured ensures that post_init check is always executed + # self.confidence_measure_cfg = OmegaConf.structured( + # self.confidence_measure_cfg + # if isinstance(self.confidence_measure_cfg, ConfidenceMethodConfig) + # else ConfidenceMethodConfig(**self.confidence_measure_cfg) + # ) + # if self.confidence_method_cfg != "DEPRECATED": + # logging.warning( + # "`confidence_method_cfg` is deprecated and will be removed in the future. " + # "Please use `confidence_measure_cfg` instead." + # ) + + # # TODO (alaptev): delete the following two lines sometime in the future + # logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.") + # # OmegaConf.structured ensures that post_init check is always executed + # self.confidence_measure_cfg = OmegaConf.structured( + # self.confidence_method_cfg + # if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig) + # else ConfidenceMethodConfig(**self.confidence_method_cfg) + # ) + # self.confidence_method_cfg = "DEPRECATED" class GreedyTDTInfer(_GreedyRNNTInfer): diff --git a/nemo/collections/asr/parts/utils/asr_confidence_utils.py b/nemo/collections/asr/parts/utils/asr_confidence_utils.py index 29c49529a509..39bf3c8b505d 100644 --- a/nemo/collections/asr/parts/utils/asr_confidence_utils.py +++ b/nemo/collections/asr/parts/utils/asr_confidence_utils.py @@ -14,7 +14,7 @@ import math from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from typing import List, Optional @@ -25,7 +25,7 @@ from nemo.utils import logging -class ConfidenceMeasureConstants: +class ConfidenceMethodConstants: NAMES = ("max_prob", "entropy") ENTROPY_TYPES = ("gibbs", "tsallis", "renyi") ENTROPY_NORMS = ("lin", "exp") @@ -48,17 +48,17 @@ def print(cls): @dataclass -class ConfidenceMeasureConfig: - """A Config which contains the measure name and settings to compute per-frame confidence scores. +class ConfidenceMethodConfig: + """A Config which contains the method name and settings to compute per-frame confidence scores. Args: - name: The measure name (str). + name: The method name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. entropy_type: Which type of entropy to use (str). - Used if confidence_measure_cfg.name is set to `entropy`. + Used if confidence_method_cfg.name is set to `entropy`. Supported values: - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). @@ -92,31 +92,25 @@ class ConfidenceMeasureConfig: def __post_init__(self): if self.temperature != "DEPRECATED": - logging.warning( - "`temperature` is deprecated and will be removed in the future. Please use `alpha` instead." - ) - - # TODO (alaptev): delete the following two lines sometime in the future - logging.warning("Re-writing `alpha` with the value of `temperature`.") # self.temperature has type str self.alpha = float(self.temperature) self.temperature = "DEPRECATED" - if self.name not in ConfidenceMeasureConstants.NAMES: + if self.name not in ConfidenceMethodConstants.NAMES: raise ValueError( f"`name` must be one of the following: " - f"{'`' + '`, `'.join(ConfidenceMeasureConstants.NAMES) + '`'}. Provided: `{self.name}`" + f"{'`' + '`, `'.join(ConfidenceMethodConstants.NAMES) + '`'}. Provided: `{self.name}`" ) - if self.entropy_type not in ConfidenceMeasureConstants.ENTROPY_TYPES: + if self.entropy_type not in ConfidenceMethodConstants.ENTROPY_TYPES: raise ValueError( f"`entropy_type` must be one of the following: " - f"{'`' + '`, `'.join(ConfidenceMeasureConstants.ENTROPY_TYPES) + '`'}. Provided: `{self.entropy_type}`" + f"{'`' + '`, `'.join(ConfidenceMethodConstants.ENTROPY_TYPES) + '`'}. Provided: `{self.entropy_type}`" ) if self.alpha <= 0.0: raise ValueError(f"`alpha` must be > 0. Provided: {self.alpha}") - if self.entropy_norm not in ConfidenceMeasureConstants.ENTROPY_NORMS: + if self.entropy_norm not in ConfidenceMethodConstants.ENTROPY_NORMS: raise ValueError( f"`entropy_norm` must be one of the following: " - f"{'`' + '`, `'.join(ConfidenceMeasureConstants.ENTROPY_NORMS) + '`'}. Provided: `{self.entropy_norm}`" + f"{'`' + '`, `'.join(ConfidenceMethodConstants.ENTROPY_NORMS) + '`'}. Provided: `{self.entropy_norm}`" ) @@ -142,15 +136,15 @@ class ConfidenceConfig: from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame + method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. - name: The measure name (str). + name: The method name (str). Supported values: - 'max_prob' for using the maximum token probability as a confidence. - 'entropy' for using a normalized entropy of a log-likelihood vector. - entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`. + entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. Supported values: - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided, the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). @@ -181,34 +175,19 @@ class ConfidenceConfig: preserve_word_confidence: bool = False exclude_blank: bool = True aggregation: str = "min" - measure_cfg: ConfidenceMeasureConfig = ConfidenceMeasureConfig() - method_cfg: str = "DEPRECATED" + method_cfg: ConfidenceMethodConfig = field(default_factory=lambda: ConfidenceMethodConfig()) def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed - self.measure_cfg = OmegaConf.structured( - self.measure_cfg - if isinstance(self.measure_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.measure_cfg) + self.method_cfg = OmegaConf.structured( + self.method_cfg + if isinstance(self.method_cfg, ConfidenceMethodConfig) + else ConfidenceMethodConfig(**self.method_cfg) ) - if self.method_cfg != "DEPRECATED": - logging.warning( - "`method_cfg` is deprecated and will be removed in the future. Please use `measure_cfg` instead." - ) - - # TODO (alaptev): delete the following two lines sometime in the future - logging.warning("Re-writing `measure_cfg` with the value of `method_cfg`.") - # OmegaConf.structured ensures that post_init check is always executed - self.measure_cfg = OmegaConf.structured( - self.method_cfg - if isinstance(self.method_cfg, ConfidenceMeasureConfig) - else ConfidenceMeasureConfig(**self.method_cfg) - ) - self.method_cfg = "DEPRECATED" if self.aggregation not in ConfidenceConstants.AGGREGATIONS: raise ValueError( f"`aggregation` has to be one of the following: " - f"{'`' + '`, `'.join(ConfidenceMeasureConstants.AGGREGATIONS) + '`'}. Provided: `{self.aggregation}`" + f"{'`' + '`, `'.join(ConfidenceConstants.AGGREGATIONS) + '`'}. Provided: `{self.aggregation}`" ) @@ -284,7 +263,7 @@ def entropy_gibbs_exp(x, v, t): def get_confidence_aggregation_bank(): """Generate a dictionary with confidence aggregation functions. - Supported confidence measures: + Supported confidence aggregation functions: min: minimum max: maximum mean: arithmetic mean @@ -305,26 +284,26 @@ def get_confidence_aggregation_bank(): return confidence_aggregation_bank -class ConfidenceMeasureMixin(ABC): - """Confidence Measure Mixin class. +class ConfidenceMethodMixin(ABC): + """Confidence Method Mixin class. - It initializes per-frame confidence measure. + It initializes per-frame confidence method. """ - def _init_confidence_measure(self, confidence_measure_cfg: Optional[DictConfig] = None): - """Initialize per-frame confidence measure from config. + def _init_confidence_method(self, confidence_method_cfg: Optional[DictConfig] = None): + """Initialize per-frame confidence method from config. """ # OmegaConf.structured ensures that post_init check is always executed - confidence_measure_cfg = OmegaConf.structured( - ConfidenceMeasureConfig() - if confidence_measure_cfg is None - else ConfidenceMeasureConfig(**confidence_measure_cfg) + confidence_method_cfg = OmegaConf.structured( + ConfidenceMethodConfig() + if confidence_method_cfg is None + else ConfidenceMethodConfig(**confidence_method_cfg) ) - # set confidence calculation measure + # set confidence calculation method # we suppose that self.blank_id == len(vocabulary) self.num_tokens = (self.blank_id if hasattr(self, "blank_id") else self._blank_index) + 1 - self.alpha = confidence_measure_cfg.alpha + self.alpha = confidence_method_cfg.alpha # init confidence measure bank self.confidence_measure_bank = get_confidence_measure_bank() @@ -332,14 +311,14 @@ def _init_confidence_measure(self, confidence_measure_cfg: Optional[DictConfig] measure = None # construct measure_name measure_name = "" - if confidence_measure_cfg.name == "max_prob": + if confidence_method_cfg.name == "max_prob": measure_name = "max_prob" - elif confidence_measure_cfg.name == "entropy": + elif confidence_method_cfg.name == "entropy": measure_name = '_'.join( - [confidence_measure_cfg.name, confidence_measure_cfg.entropy_type, confidence_measure_cfg.entropy_norm] + [confidence_method_cfg.name, confidence_method_cfg.entropy_type, confidence_method_cfg.entropy_norm] ) else: - raise ValueError(f"Unsupported `confidence_measure_cfg.name`: `{confidence_measure_cfg.name}`") + raise ValueError(f"Unsupported `confidence_method_cfg.name`: `{confidence_method_cfg.name}`") if measure_name not in self.confidence_measure_bank: raise ValueError(f"Unsupported measure setup: `{measure_name}`") measure = partial(self.confidence_measure_bank[measure_name], v=self.num_tokens, t=self.alpha) @@ -359,7 +338,7 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): confidence_cfg = OmegaConf.structured( ConfidenceConfig() if confidence_cfg is None else ConfidenceConfig(**confidence_cfg) ) - self.confidence_measure_cfg = confidence_cfg.measure_cfg + self.confidence_method_cfg = confidence_cfg.method_cfg # extract the config self.preserve_word_confidence = confidence_cfg.get('preserve_word_confidence', False) @@ -384,11 +363,11 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None): if self.cfg.strategy in ['greedy', 'greedy_batch']: self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False) # OmegaConf.structured ensures that post_init check is always executed - confidence_measure_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_measure_cfg', None) - self.confidence_measure_cfg = ( - OmegaConf.structured(ConfidenceMeasureConfig()) - if confidence_measure_cfg is None - else OmegaConf.structured(ConfidenceMeasureConfig(**confidence_measure_cfg)) + confidence_method_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_method_cfg', None) + self.confidence_method_cfg = ( + OmegaConf.structured(ConfidenceMethodConfig()) + if confidence_method_cfg is None + else OmegaConf.structured(ConfidenceMethodConfig(**confidence_method_cfg)) ) @abstractmethod @@ -479,4 +458,4 @@ def _aggregate_token_confidence_subwords_sentencepiece( len(word_confidence): {len(word_confidence)},\n recognized text: `{' '.join(words)}`""" ) - return word_confidence + return word_confidence \ No newline at end of file diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc_confidence.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc_confidence.py index c877770222c4..6b3762255c74 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc_confidence.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc_confidence.py @@ -155,24 +155,34 @@ class EvalBeamSearchNGramConfig: def merge_alignment_with_wb_hyps( - alignment, + candidate, model, wb_result, ): + alignment = candidate.alignments alignment_per_frame = [] for items in alignment: current_frame_ali = [x[1].item() for x in items] + # logging.warning("-----"*10) + # logging.warning(current_frame_ali) alignment_per_frame.append(current_frame_ali) + frame_confidence = [] + for items in candidate.frame_confidence: + # current_frame_confidence = [x for x in items] + frame_confidence.append(items) + + assert len(alignment_per_frame) == len(frame_confidence) + # alignment = [x[0][1].item() for x in alignment] # get words borders alignment_tokens = [] for idx, frame_ali in enumerate(alignment_per_frame): - for token in frame_ali: + for idy, token in enumerate(frame_ali): if token != model.decoder.blank_idx: - alignment_tokens.append([idx, model.tokenizer.ids_to_tokens([token])[0]]) + alignment_tokens.append([idx, model.tokenizer.ids_to_tokens([token])[0], frame_confidence[idx][idy]]) if not alignment_tokens: return " ".join([wb_hyp.word for wb_hyp in wb_result]) @@ -187,23 +197,26 @@ def merge_alignment_with_wb_hyps( word = item[1][1:] l = item[0] r = item[0] + word_confidence = [item[2]] else: if item[1].startswith(slash): - word_alignment.append((word, l, r)) + word_alignment.append((word, l, r, word_confidence)) word = item[1][1:] l = item[0] r = item[0] + word_confidence = [item[2]] else: word += item[1] r = item[0] - word_alignment.append((word, l, r)) + word_confidence.append(item[2]) + word_alignment.append((word, l, r, word_confidence)) ref_text = [item[0] for item in word_alignment] ref_text = " ".join(ref_text) - print(f"before: {ref_text}") # merge wb_hyps and word alignment: for wb_hyp in wb_result: new_word_alignment = [] + rnnt_spot_info = [] already_pasted = False lh, rh = wb_hyp.start_frame, wb_hyp.end_frame for item in word_alignment: @@ -212,13 +225,17 @@ def merge_alignment_with_wb_hyps( if not already_pasted: new_word_alignment.append((wb_hyp.word, wb_hyp.start_frame, wb_hyp.end_frame)) already_pasted = True + rnnt_spot_info.append([item[0], item[3]]) else: new_word_alignment.append(item) word_alignment = new_word_alignment + print(f"wb_hyp: {wb_hyp.word}") + print(f"spot info: {rnnt_spot_info}") # boosted_text_list = [wb_hyp.word for wb_hyp in new_word_alignment] boosted_text_list = [item[0] for item in new_word_alignment] boosted_text = " ".join(boosted_text_list) + print(f"before: {ref_text}") print(f"after : {boosted_text}") return boosted_text @@ -317,7 +334,7 @@ def decoding_step( # make new text by mearging alignment with ctc-wb predictions: print("----") boosted_text = merge_alignment_with_wb_hyps( - candidate.alignments, + candidate, model, wb_results[audio_file_paths[sample_idx + beams_idx]] ) @@ -585,21 +602,20 @@ def default_autocast(): if cfg.use_confidence: cfg.confidence_cfg = ConfidenceConfig( preserve_frame_confidence=True, # Internally set to true if preserve_token_confidence == True - # or preserve_word_confidence == True - preserve_token_confidence=False, # Internally set to true if preserve_word_confidence == True - preserve_word_confidence=False, - aggregation="min", # How to aggregate frame scores to token scores and token scores to word scores - exclude_blank=True, # If true, only non-blank emissions contribute to confidence scores - method_cfg=ConfidenceMethodConfig( # Config for per-frame scores calculation (before aggregation) - name="max_prob", # Or "entropy" (default), which usually works better - entropy_type="gibbs", # Used only for name == "entropy". Recommended: "tsallis" (default) or "renyi" - alpha=0.5, # Low values (<1) increase sensitivity, high values decrease sensitivity - entropy_norm="lin" # How to normalize (map to [0,1]) entropy. Default: "exp" ) - ) - # logging.warning("-------------") - # logging.warning(f"confidence_cfg is: {cfg.confidence_cfg}") - # raise KeyError + # cfg.confidence_cfg = ConfidenceConfig( + # preserve_frame_confidence=True, # Internally set to true if preserve_token_confidence == True + # # or preserve_word_confidence == True + # preserve_token_confidence=False, # Internally set to true if preserve_word_confidence == True + # preserve_word_confidence=False, + # aggregation="min", # How to aggregate frame scores to token scores and token scores to word scores + # exclude_blank=True, # If true, only non-blank emissions contribute to confidence scores + # method_cfg=ConfidenceMethodConfig( # Config for per-frame scores calculation (before aggregation) + # name="max_prob", # Or "entropy" (default), which usually works better + # entropy_type="gibbs", # Used only for name == "entropy". Recommended: "tsallis" (default) or "renyi" + # alpha=0.5, # Low values (<1) increase sensitivity, high values decrease sensitivity + # entropy_norm="lin" # How to normalize (map to [0,1]) entropy. Default: "exp" + # ) ################################