Skip to content

Commit

Permalink
add confidence
Browse files Browse the repository at this point in the history
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
  • Loading branch information
andrusenkoau committed Oct 25, 2023
1 parent 04b7ec7 commit a900e62
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 158 deletions.
12 changes: 6 additions & 6 deletions nemo/collections/asr/metrics/rnnt_wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)
else:
self.decoding = greedy_decode.GreedyTDTInfer(
Expand All @@ -291,7 +291,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)
else:
self.decoding = greedy_decode.GreedyMultiblankRNNTInfer(
Expand All @@ -304,7 +304,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)

elif self.cfg.strategy == 'greedy_batch':
Expand All @@ -320,7 +320,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)
else:
self.decoding = greedy_decode.GreedyBatchedTDTInfer(
Expand All @@ -334,7 +334,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)

else:
Expand All @@ -348,7 +348,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)

elif self.cfg.strategy == 'beam':
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def __init__(self, decoding_cfg, blank_id: int):
preserve_alignments=self.preserve_alignments,
compute_timestamps=self.compute_timestamps,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_measure_cfg=self.confidence_measure_cfg,
confidence_method_cfg=self.confidence_method_cfg,
)

elif self.cfg.strategy == 'beam':
Expand Down
35 changes: 18 additions & 17 deletions nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin
# from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
from nemo.utils import logging
Expand Down Expand Up @@ -55,7 +56,7 @@ def _states_to_device(dec_state, device='cpu'):
return dec_state


class GreedyCTCInfer(Typing, ConfidenceMeasureMixin):
class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
"""A greedy CTC decoder.
Provides a common abstraction for sample level and batch level greedy decoding.
Expand All @@ -71,15 +72,15 @@ class GreedyCTCInfer(Typing, ConfidenceMeasureMixin):
preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
generated during decoding. When set to true, the Hypothesis will contain
the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
confidence_measure_cfg: A dict-like object which contains the measure name and settings to compute per-frame
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
name: The method name (str).
Supported values:
- 'max_prob' for using the maximum token probability as a confidence.
- 'entropy' for using a normalized entropy of a log-likelihood vector.
entropy_type: Which type of entropy to use (str). Used if confidence_measure_cfg.name is set to `entropy`.
entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
Supported values:
- 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(
preserve_alignments: bool = False,
compute_timestamps: bool = False,
preserve_frame_confidence: bool = False,
confidence_measure_cfg: Optional[DictConfig] = None,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__()

Expand All @@ -141,7 +142,7 @@ def __init__(
self.preserve_frame_confidence = preserve_frame_confidence

# set confidence calculation method
self._init_confidence_measure(confidence_measure_cfg)
self._init_confidence_method(confidence_method_cfg)

@typecheck()
def forward(
Expand Down Expand Up @@ -253,27 +254,27 @@ class GreedyCTCInferConfig:
preserve_alignments: bool = False
compute_timestamps: bool = False
preserve_frame_confidence: bool = False
confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig()
confidence_method_cfg: str = "DEPRECATED"
confidence_method_cfg: Optional[ConfidenceMethodConfig] = ConfidenceMethodConfig()
# confidence_method_cfg: str = "DEPRECATED"

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_measure_cfg
if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_measure_cfg)
self.confidence_method_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
else ConfidenceMethodConfig(**self.confidence_method_cfg)
)
if self.confidence_method_cfg != "DEPRECATED":
logging.warning(
"`confidence_method_cfg` is deprecated and will be removed in the future. "
"Please use `confidence_measure_cfg` instead."
"Please use `confidence_method_cfg` instead."
)

# TODO (alaptev): delete the following two lines sometime in the future
logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.")
logging.warning("Re-writing `confidence_method_cfg` with the value of `confidence_method_cfg`.")
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_method_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_method_cfg)
if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
else ConfidenceMethodConfig(**self.confidence_method_cfg)
)
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,7 @@ class BeamRNNTInferConfig:
language_model: Optional[Dict[str, Any]] = None
softmax_temperature: float = 1.0
preserve_alignments: bool = False
preserve_frame_confidence: bool = True
ngram_lm_model: Optional[str] = None
ngram_lm_alpha: Optional[float] = 0.0
hat_subtract_ilm: bool = False
Expand Down
119 changes: 71 additions & 48 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union

import numpy as np
Expand All @@ -35,7 +35,8 @@

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin
# from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMeasureConfig, ConfidenceMeasureMixin
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.rnn import label_collate
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import AcousticEncodedRepresentation, ElementType, HypothesisType, LengthsType, NeuralType
Expand Down Expand Up @@ -69,7 +70,7 @@ def _states_to_device(dec_state, device='cpu'):
return dec_state


class _GreedyRNNTInfer(Typing, ConfidenceMeasureMixin):
class _GreedyRNNTInfer(Typing, ConfidenceMethodMixin):
"""A greedy transducer decoder.
Provides a common abstraction for sample level and batch level greedy decoding.
Expand Down Expand Up @@ -154,7 +155,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_measure_cfg: Optional[DictConfig] = None,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__()
self.decoder = decoder_model
Expand All @@ -167,7 +168,7 @@ def __init__(
self.preserve_frame_confidence = preserve_frame_confidence

# set confidence calculation method
self._init_confidence_measure(confidence_measure_cfg)
self._init_confidence_method(confidence_method_cfg)

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
Expand Down Expand Up @@ -544,7 +545,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_measure_cfg: Optional[DictConfig] = None,
confidence_method_cfg: Optional[DictConfig] = None,
):
super().__init__(
decoder_model=decoder_model,
Expand All @@ -553,7 +554,7 @@ def __init__(
max_symbols_per_step=max_symbols_per_step,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
confidence_measure_cfg=confidence_measure_cfg,
confidence_method_cfg=confidence_method_cfg,
)

# Depending on availability of `blank_as_pad` support
Expand Down Expand Up @@ -2204,68 +2205,90 @@ def _greedy_decode_masked(
return hypotheses


# @dataclass
# class GreedyRNNTInferConfig:
# max_symbols_per_step: Optional[int] = 10
# preserve_alignments: bool = False
# preserve_frame_confidence: bool = False
# confidence_measure_cfg: Optional[ConfidenceMethodConfig] = ConfidenceMethodConfig()
# confidence_method_cfg: str = "DEPRECATED"

# def __post_init__(self):
# # OmegaConf.structured ensures that post_init check is always executed
# self.confidence_measure_cfg = OmegaConf.structured(
# self.confidence_measure_cfg
# if isinstance(self.confidence_measure_cfg, ConfidenceMethodConfig)
# else ConfidenceMethodConfig(**self.confidence_measure_cfg)
# )
# if self.confidence_method_cfg != "DEPRECATED":
# logging.warning(
# "`confidence_method_cfg` is deprecated and will be removed in the future. "
# "Please use `confidence_measure_cfg` instead."
# )

# # TODO (alaptev): delete the following two lines sometime in the future
# logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.")
# # OmegaConf.structured ensures that post_init check is always executed
# self.confidence_measure_cfg = OmegaConf.structured(
# self.confidence_method_cfg
# if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
# else ConfidenceMethodConfig(**self.confidence_method_cfg)
# )
# self.confidence_method_cfg = "DEPRECATED"

@dataclass
class GreedyRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig()
confidence_method_cfg: str = "DEPRECATED"
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_measure_cfg
if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_measure_cfg)
self.confidence_method_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
else ConfidenceMethodConfig(**self.confidence_method_cfg)
)
if self.confidence_method_cfg != "DEPRECATED":
logging.warning(
"`confidence_method_cfg` is deprecated and will be removed in the future. "
"Please use `confidence_measure_cfg` instead."
)

# TODO (alaptev): delete the following two lines sometime in the future
logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.")
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_method_cfg)
)
self.confidence_method_cfg = "DEPRECATED"


@dataclass
class GreedyBatchedRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_measure_cfg: Optional[ConfidenceMeasureConfig] = ConfidenceMeasureConfig()
confidence_method_cfg: str = "DEPRECATED"
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())

def __post_init__(self):
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_measure_cfg
if isinstance(self.confidence_measure_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_measure_cfg)
self.confidence_method_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
else ConfidenceMethodConfig(**self.confidence_method_cfg)
)
if self.confidence_method_cfg != "DEPRECATED":
logging.warning(
"`confidence_method_cfg` is deprecated and will be removed in the future. "
"Please use `confidence_measure_cfg` instead."
)

# TODO (alaptev): delete the following two lines sometime in the future
logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.")
# OmegaConf.structured ensures that post_init check is always executed
self.confidence_measure_cfg = OmegaConf.structured(
self.confidence_method_cfg
if isinstance(self.confidence_method_cfg, ConfidenceMeasureConfig)
else ConfidenceMeasureConfig(**self.confidence_method_cfg)
)
self.confidence_method_cfg = "DEPRECATED"
# def __post_init__(self):
# # OmegaConf.structured ensures that post_init check is always executed
# self.confidence_measure_cfg = OmegaConf.structured(
# self.confidence_measure_cfg
# if isinstance(self.confidence_measure_cfg, ConfidenceMethodConfig)
# else ConfidenceMethodConfig(**self.confidence_measure_cfg)
# )
# if self.confidence_method_cfg != "DEPRECATED":
# logging.warning(
# "`confidence_method_cfg` is deprecated and will be removed in the future. "
# "Please use `confidence_measure_cfg` instead."
# )

# # TODO (alaptev): delete the following two lines sometime in the future
# logging.warning("Re-writing `confidence_measure_cfg` with the value of `confidence_method_cfg`.")
# # OmegaConf.structured ensures that post_init check is always executed
# self.confidence_measure_cfg = OmegaConf.structured(
# self.confidence_method_cfg
# if isinstance(self.confidence_method_cfg, ConfidenceMethodConfig)
# else ConfidenceMethodConfig(**self.confidence_method_cfg)
# )
# self.confidence_method_cfg = "DEPRECATED"


class GreedyTDTInfer(_GreedyRNNTInfer):
Expand Down
Loading

0 comments on commit a900e62

Please sign in to comment.