Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ASR] Support CTC decoder online #821

Merged
merged 14 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 7 additions & 29 deletions examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,39 +162,17 @@ def forward(self, audio, audio_len, text, text_len):
return loss

@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)

# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)

def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size = batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
Jackwaterveg marked this conversation as resolved.
Show resolved Hide resolved
return trans_best

@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
Expand Down
34 changes: 17 additions & 17 deletions examples/other/1xt2x/src_deepspeech2x/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,10 @@ def compute_metrics(self,
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

vocab_list = self.test_loader.collate_fn.vocab_list

target_transcripts = self.ordid2token(texts, texts_len)

result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
result_transcripts = self.compute_result_transcripts(audio, audio_len)

for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
Expand All @@ -280,19 +278,9 @@ def compute_metrics(self,
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)

def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode(audio, audio_len)

result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
Expand All @@ -307,6 +295,17 @@ def test(self):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0

# Initialize the decoder in the model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
Expand All @@ -326,6 +325,7 @@ def test(self):
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.model.decoder.del_decoder()

def run_test(self):
self.resume_or_scratch()
Expand Down
5 changes: 5 additions & 0 deletions paddlespeech/s2t/decoders/ctcdecoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .swig_wrapper import ctc_beam_search_decoding
from .swig_wrapper import ctc_beam_search_decoding_batch
from .swig_wrapper import ctc_greedy_decoding
from .swig_wrapper import CTCBeamSearchDecoder
from .swig_wrapper import Scorer
77 changes: 51 additions & 26 deletions paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders
import paddlespeech_ctcdecoders


class Scorer(swig_decoders.Scorer):
class Scorer(paddlespeech_ctcdecoders.Scorer):
"""Wrapper for Scorer.

:param alpha: Parameter associated with language model. Don't use
Expand All @@ -26,14 +26,17 @@ class Scorer(swig_decoders.Scorer):
:type beta: float
:param model_path: Path to load language model.
:type model_path: str
:param vocabulary: Vocabulary list.
:type vocabulary: list
"""

def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path,
vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig.
def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoding function in swig.

:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
Expand All @@ -44,19 +47,19 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
:return: Decoding result string.
:rtype: str
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(),
vocabulary, blank_id)
return result


def ctc_beam_search_decoder(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder.
def ctc_beam_search_decoding(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoding function.

:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
Expand All @@ -81,22 +84,22 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results


def ctc_beam_search_decoder_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder.
def ctc_beam_search_decoding_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoding function.

:param probs_split: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoding().
Expand Down Expand Up @@ -126,9 +129,31 @@ def ctc_beam_search_decoder_batch(probs_split,
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]

batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results


class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
    """Wrapper for CtcBeamSearchDecoderBatch.

    Thin subclass whose constructor forwards every argument, unchanged and
    in the same order, to the base class from ``paddlespeech_ctcdecoders``.
    None of the arguments has a default value here, so all of them must be
    passed explicitly.

    Args:
        vocab_list (list): Vocabulary list.
        batch_size (int): Batch size handed to the underlying decoder
            (presumably the number of utterances decoded together --
            TODO confirm against the binding).
        beam_size (int): Width for beam search.
        num_processes (int): Number of parallel processes.
        cutoff_prob (float): Cutoff probability in vocabulary pruning;
            1.0 means no pruning.
        cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n
            characters with highest probs in vocabulary will be
            used in beam search.
        _ext_scorer (Scorer): External scorer for partially decoded
            sentence, e.g. word count or language model.
        blank_id (int): Presumably the index of the CTC blank token in
            the vocabulary -- TODO confirm against the binding.
    """

    def __init__(self, vocab_list, batch_size, beam_size, num_processes,
                 cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
        paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__(
            self, vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
            cutoff_top_n, _ext_scorer, blank_id)
Loading