Skip to content

Commit

Permalink
add reset_stage for ctcdecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackwaterveg committed Jan 21, 2022
1 parent 5138abc commit d76cfcc
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 163 deletions.
13 changes: 3 additions & 10 deletions examples/other/1xt2x/src_deepspeech2x/models/ds2/deepspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,15 @@ def forward(self, audio, audio_len, text, text_len):
return loss

@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
batch_size = audio.shape[0]
self.decoder.init_decoder(batch_size, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
num_processes)

# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
self.decoder.del_decoder()
self.decoder.reset_decoder()
return trans_best

@classmethod
Expand Down
33 changes: 17 additions & 16 deletions examples/other/1xt2x/src_deepspeech2x/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,11 @@ def compute_metrics(self,
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

vocab_list = self.test_loader.collate_fn.vocab_list

target_transcripts = self.ordid2token(texts, texts_len)

result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
result_transcripts = self.compute_result_transcripts(audio, audio_len)

for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
Expand All @@ -280,19 +279,9 @@ def compute_metrics(self,
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)

def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode(audio, audio_len)

result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
Expand All @@ -307,6 +296,17 @@ def test(self):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0

# Initialize the decoder in the model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
Expand All @@ -326,6 +326,7 @@ def test(self):
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.model.decoder.del_decoder()

def run_test(self):
self.resume_or_scratch()
Expand Down
74 changes: 48 additions & 26 deletions paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ class Scorer(paddlespeech_ctcdecoders.Scorer):
:type beta: float
:model_path: Path to load language model.
:type model_path: str
:param vocabulary: Vocabulary list.
:type vocabulary: list
"""

def __init__(self, alpha, beta, model_path, vocabulary):
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path,
vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig.
def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decodeing function in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
Expand All @@ -44,19 +47,19 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
:return: Decoding result string.
:rtype: str
"""
result = paddlespeech_ctcdecoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(),
vocabulary, blank_id)
return result


def ctc_beam_search_decoder(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder.
def ctc_beam_search_decoding(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoding function.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
Expand All @@ -81,22 +84,22 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoder(
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results


def ctc_beam_search_decoder_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder.
def ctc_beam_search_decoding_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decodeing batch function.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
Expand Down Expand Up @@ -126,13 +129,32 @@ def ctc_beam_search_decoder_batch(probs_split,
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]

batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoder_batch(
batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results


def get_ctc_beam_search_decoder_batch_class():
return paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch
class CTC_beam_search_decoder(
        paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
    """Thin Python wrapper around the swig CtcBeamSearchDecoderBatch decoder.

    Args:
        vocab_list (list): Vocabulary list.
        batch_size (int): Number of utterances decoded per batch.
        beam_size (int): Width for beam search.
        num_processes (int): Number of parallel processes.
        cutoff_prob (float): Cutoff probability in vocabulary pruning;
            1.0 means no pruning.
        cutoff_top_n (int): Cutoff number in pruning; only the top
            cutoff_top_n characters with the highest probabilities in the
            vocabulary are used in beam search.
        _ext_scorer (Scorer): External scorer for partially decoded
            sentences, e.g. word count or language model.
        blank_id (int): Index of the CTC blank symbol in the vocabulary
            (assumed — confirm against the swig decoder's contract).
    """

    def __init__(self, vocab_list, batch_size, beam_size, num_processes,
                 cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
        # All work happens in the C++ base class; delegate unchanged.
        super().__init__(vocab_list, batch_size, beam_size, num_processes,
                         cutoff_prob, cutoff_top_n, _ext_scorer, blank_id)
Loading

0 comments on commit d76cfcc

Please sign in to comment.