Skip to content

Commit

Permalink
fix the destructor problem for prefixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackwaterveg committed Jan 19, 2022
1 parent 4a13361 commit 4756c7d
Show file tree
Hide file tree
Showing 9 changed files with 596 additions and 19 deletions.
13 changes: 13 additions & 0 deletions paddlespeech/s2t/decoders/ctcdecoder/swig_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,16 @@ def ctc_beam_search_decoder_batch(probs_split,
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results


def get_ctc_beam_search_chunk_decoder(vocabulary, batch_size, beam_size,
                                      num_processes, cutoff_prob, cutoff_top_n,
                                      ext_scoring_func, blank_id):
    """Build a batch CTC beam-search decoder for chunk-by-chunk decoding.

    Every argument is forwarded unchanged, positionally and in the order
    given here, to the ``swig_decoders.CtcBeamSearchDecoderBatch``
    constructor.

    Returns:
        The newly constructed ``CtcBeamSearchDecoderBatch`` instance.
    """
    decoder_args = (vocabulary, batch_size, beam_size, num_processes,
                    cutoff_prob, cutoff_top_n, ext_scoring_func, blank_id)
    return swig_decoders.CtcBeamSearchDecoderBatch(*decoder_args)


def get_ctc_beam_search_decoder_batch_class():
    """Return the ``CtcBeamSearchDecoderBatch`` class itself (not an instance).

    Callers instantiate the returned class with their own constructor
    arguments.
    """
    decoder_cls = swig_decoders.CtcBeamSearchDecoderBatch
    return decoder_cls
32 changes: 27 additions & 5 deletions paddlespeech/s2t/exps/deepspeech2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,9 @@ def test(self):
def compute_result_transcripts(self, audio, audio_len, vocab_list,
decode_cfg):
if self.args.model_type == "online":
output_probs, output_lens = self.static_forward_online(audio,
audio_len)
output_probs, output_lens, batch_trans_list = self.static_forward_online(
audio, audio_len, vocab_list, decode_cfg)
logger.info(batch_trans_list)
elif self.args.model_type == "offline":
output_probs, output_lens = self.static_forward_offline(audio,
audio_len)
Expand All @@ -422,12 +423,17 @@ def compute_result_transcripts(self, audio, audio_len, vocab_list,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

#replace the <space> with ' '
result_transcripts = [
self._text_featurizer.detokenize(sentence)
for sentence in result_transcripts
]

# check the decode chunk result is same to the original decode result
if self.args.model_type == "online":
assert (batch_trans_list[-1] == result_transcripts[-1])

return result_transcripts

def run_test(self):
Expand All @@ -439,7 +445,11 @@ def run_test(self):
except KeyboardInterrupt:
exit(-1)

def static_forward_online(self, audio, audio_len,
def static_forward_online(self,
audio,
audio_len,
vocab_list,
decode_cfg,
decoder_chunk_size: int=1):
"""
Parameters
Expand Down Expand Up @@ -472,6 +482,7 @@ def static_forward_online(self, audio, audio_len,
x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, batch_size, axis=0)

batch_trans_list = []
for x, x_len in zip(x_list, x_len_list):
if self.args.enable_auto_log is True:
self.autolog.times.start()
Expand Down Expand Up @@ -504,12 +515,18 @@ def static_forward_online(self, audio, audio_len,
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])

trans_chunk_list = []
probs_chunk_list = []
probs_chunk_lens_list = []
if self.args.enable_auto_log is True:
# record the model preprocessing time
self.autolog.times.stamp()

self.model.init_chunk_decoder(
1, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
Expand Down Expand Up @@ -549,9 +566,14 @@ def static_forward_online(self, audio, audio_len,
output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()

self.model.decode_get_next(
probs=output_chunk_probs, probs_len=output_chunk_lens)
probs_chunk_list.append(output_chunk_probs)
probs_chunk_lens_list.append(output_chunk_lens)
trans_best, trans_beam = self.model.decode_get_trans()
batch_trans_list.append(trans_best[0])
self.model.del_chunk_decoder()

output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0)
vocab_size = output_probs.shape[2]
Expand All @@ -573,7 +595,7 @@ def static_forward_online(self, audio, audio_len,
self.autolog.times.end()
output_probs = np.concatenate(output_probs_list, axis=0)
output_lens = np.concatenate(output_lens_list, axis=0)
return output_probs, output_lens
return output_probs, output_lens, batch_trans_list

def static_forward_offline(self, audio, audio_len):
"""
Expand Down
69 changes: 69 additions & 0 deletions paddlespeech/s2t/models/ds2_online/deepspeech2.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ def __init__(
batch_average=True, # sum / batch_size
grad_norm_type=ctc_grad_norm_type)

self.chunk_decoder = None

def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss
Expand Down Expand Up @@ -313,6 +315,73 @@ def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)

@paddle.no_grad()
def decode_chunk_by_chunk(self, audio, audio_len, vocab_list,
                          decoding_method, lang_model_path, beam_alpha,
                          beam_beta, beam_size, cutoff_prob, cutoff_top_n,
                          num_processes):
    """Encode ``audio`` chunk by chunk and decode it with a fresh chunk decoder.

    A new chunk decoder sized to ``audio.shape[0]`` replaces any decoder
    left over from a previous call. Each encoder output chunk is passed
    through the decoder's softmax and fed to the chunk decoder; the
    transcripts are read once all chunks have been consumed.

    Args:
        audio: Batch of input features; ``audio.shape[0]`` is used as the
            batch size.
        audio_len: Lengths matching ``audio``, forwarded to the encoder.
        vocab_list: Vocabulary forwarded to the decoder.
        decoding_method: Decoding method name forwarded to ``init_decode``.
        lang_model_path: Language-model path forwarded to ``init_decode``.
        beam_alpha: Forwarded to the decoder initialisation.
        beam_beta: Forwarded to the decoder initialisation.
        beam_size: Beam width forwarded to the chunk decoder.
        cutoff_prob: Forwarded to the chunk decoder.
        cutoff_top_n: Forwarded to the chunk decoder.
        num_processes: Forwarded to the chunk decoder.

    Returns:
        The first element returned by ``self.decoder.chunk_decoder_decode``
        (the best transcript per batch item).
    """
    # Reuse the shared setup path: initialise the CTC decoder, drop any
    # previous chunk decoder and create a new one for this batch size.
    self.init_chunk_decoder(audio.shape[0], vocab_list, decoding_method,
                            lang_model_path, beam_alpha, beam_beta,
                            beam_size, cutoff_prob, cutoff_top_n,
                            num_processes)

    # The final recurrent states returned by the encoder are not needed here.
    eouts_chunk_list, eouts_chunk_lens_list, _, _ = (
        self.encoder.forward_chunk_by_chunk(
            audio, audio_len, decoder_chunk_size=1))

    for eouts, eouts_len in zip(eouts_chunk_list, eouts_chunk_lens_list):
        probs = self.decoder.softmax(eouts)
        self.decoder.chunk_decoder_next(self.chunk_decoder, probs, eouts_len)

    trans_best, _ = self.decoder.chunk_decoder_decode(self.chunk_decoder)
    return trans_best

@paddle.no_grad()
def init_chunk_decoder(self, batch_size, vocab_list, decoding_method,
                       lang_model_path, beam_alpha, beam_beta, beam_size,
                       cutoff_prob, cutoff_top_n, num_processes):
    """Initialise the CTC decoder and (re)create ``self.chunk_decoder``.

    Any chunk decoder that already exists is released through
    ``del_chunk_decoder`` before the replacement is built.

    Args:
        batch_size: Batch size the new chunk decoder is created for.
        vocab_list: Vocabulary forwarded to the decoder.
        decoding_method: Decoding method name forwarded to ``init_decode``.
        lang_model_path: Language-model path forwarded to ``init_decode``.
        beam_alpha: Forwarded to ``init_decode`` and ``get_chunk_decoder``.
        beam_beta: Forwarded to ``init_decode`` and ``get_chunk_decoder``.
        beam_size: Forwarded to ``get_chunk_decoder``.
        cutoff_prob: Forwarded to ``get_chunk_decoder``.
        cutoff_top_n: Forwarded to ``get_chunk_decoder``.
        num_processes: Forwarded to ``get_chunk_decoder``.
    """
    self.decoder.init_decode(
        decoding_method=decoding_method,
        vocab_list=vocab_list,
        lang_model_path=lang_model_path,
        beam_alpha=beam_alpha,
        beam_beta=beam_beta)

    has_stale_decoder = self.chunk_decoder is not None
    if has_stale_decoder:
        self.del_chunk_decoder()

    fresh_decoder = self.decoder.get_chunk_decoder(
        vocab_list, batch_size, beam_alpha, beam_beta, beam_size,
        num_processes, cutoff_prob, cutoff_top_n)
    self.chunk_decoder = fresh_decoder

@paddle.no_grad()
def decode_get_next(self, probs, probs_len):
    """Feed one chunk of probabilities to the current chunk decoder.

    Args:
        probs: Probabilities for the chunk, forwarded to
            ``self.decoder.chunk_decoder_next``.
        probs_len: Lengths matching ``probs``, forwarded alongside it.

    Raises:
        Exception: If no chunk decoder has been initialised yet.
    """
    active_decoder = self.chunk_decoder
    if active_decoder is None:
        raise Exception("You need to initialize the chunk decoder firstly")
    self.decoder.chunk_decoder_next(active_decoder, probs, probs_len)

def decode_get_trans(self):
    """Return the transcripts accumulated so far by the chunk decoder.

    Returns:
        A ``(trans_best, trans_beam)`` tuple as produced by
        ``self.decoder.chunk_decoder_decode``.

    Raises:
        Exception: If no chunk decoder has been initialised yet.
    """
    active_decoder = self.chunk_decoder
    if active_decoder is None:
        raise Exception("You need to initialize the chunk decoder firstly")
    best, beam = self.decoder.chunk_decoder_decode(active_decoder)
    return best, beam

def del_chunk_decoder(self):
    """Release the current chunk decoder, if any, and reset it to ``None``.

    The attribute is deleted first so this object's reference to the
    decoder is dropped, then re-created as ``None`` so later ``is None``
    checks keep working. Does nothing when no decoder is set.
    """
    if self.chunk_decoder is None:
        return
    del self.chunk_decoder
    self.chunk_decoder = None

@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Expand Down
58 changes: 58 additions & 0 deletions paddlespeech/s2t/modules/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import get_ctc_beam_search_chunk_decoder # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import get_ctc_beam_search_decoder_batch_class
except ImportError:
try:
from paddlespeech.s2t.utils import dynamic_pip_install
Expand Down Expand Up @@ -242,6 +244,7 @@ def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method):

self.decoding_method = decoding_method
if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list)
Expand Down Expand Up @@ -288,3 +291,58 @@ def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
else:
raise ValueError(f"Not support: {decoding_method}")
return result_transcripts

def get_chunk_decoder(self, vocabulary, batch_size, beam_alpha, beam_beta,
                      beam_size, num_processes, cutoff_prob, cutoff_top_n):
    """Create a batch beam-search decoder for chunk-by-chunk decoding.

    Uses ``self.decoding_method`` to pick the decoder; only
    ``"ctc_beam_search"`` is supported.

    Args:
        vocabulary: Vocabulary forwarded to the decoder constructor.
        batch_size: Number of utterances decoded together.
        beam_alpha: Passed to the external scorer's ``reset_params``.
        beam_beta: Passed to the external scorer's ``reset_params``.
        beam_size: Forwarded to the decoder constructor.
        num_processes: Requested worker count; capped at ``batch_size``.
        cutoff_prob: Forwarded to the decoder constructor.
        cutoff_top_n: Forwarded to the decoder constructor.

    Returns:
        The newly constructed chunk decoder.

    Raises:
        ValueError: If ``self.decoding_method`` is not ``"ctc_beam_search"``.
    """
    # Never request more workers than there are items in the batch.
    num_processes = min(num_processes, batch_size)
    if self._ext_scorer is not None:
        self._ext_scorer.reset_params(beam_alpha, beam_beta)

    if self.decoding_method != "ctc_beam_search":
        # The method name lives on the instance; there is no local variable
        # called ``decoding_method`` in this scope.
        raise ValueError(f"Not support: {self.decoding_method}")

    DecoderClass = get_ctc_beam_search_decoder_batch_class()
    return DecoderClass(vocabulary, batch_size, beam_size, num_processes,
                        cutoff_prob, cutoff_top_n, self._ext_scorer,
                        self.blank_id)

def chunk_decoder_next(self, chunk_decoder, probs, logits_lens):
    """Feed one chunk of probabilities into ``chunk_decoder``.

    Args:
        chunk_decoder: Decoder object exposing ``next(probs_split, has_value)``.
        probs: Array-like indexed as ``probs[i, :l, :]``; assumed to be
            (batch, time, vocab) -- confirm against the caller.
        logits_lens: Per-item valid lengths; must support ``> 0`` followed by
            ``.tolist()`` and iteration.

    Raises:
        ValueError: If ``self.decoding_method`` is not ``"ctc_beam_search"``.
    """
    if self.decoding_method != "ctc_beam_search":
        # The method name lives on the instance; there is no local variable
        # called ``decoding_method`` in this scope.
        raise ValueError(f"Not support: {self.decoding_method}")

    # The decoder takes the per-item "has data" flags as the strings
    # "true"/"false" rather than as booleans.
    has_value = [
        "true" if flag is True else "false"
        for flag in (logits_lens > 0).tolist()
    ]
    # Every item is trimmed to its valid length (an empty list when the
    # length is 0). The previous conditional tested the non-empty flag
    # strings, which are always truthy, so this is what it always did.
    probs_split = [
        probs[i, :l, :].tolist() for i, l in enumerate(logits_lens)
    ]
    chunk_decoder.next(probs_split, has_value)

def chunk_decoder_decode(self, chunk_decoder):
    """Read the current beam-search results out of ``chunk_decoder``.

    ``chunk_decoder.decode()`` is expected to return, per batch item, a
    sequence of beam entries whose element ``[1]`` is the transcript
    (element ``[0]`` is presumably its score -- confirm against the decoder).

    Args:
        chunk_decoder: Decoder object exposing ``decode()``.

    Returns:
        A ``(results_best, results_beam)`` tuple: the first beam entry's
        transcript for each batch item, and all beam transcripts for each
        batch item.

    Raises:
        ValueError: If ``self.decoding_method`` is not ``"ctc_beam_search"``.
    """
    if self.decoding_method != "ctc_beam_search":
        # The method name lives on the instance; there is no local variable
        # called ``decoding_method`` in this scope.
        raise ValueError(f"Not support: {self.decoding_method}")

    batch_beam_results = chunk_decoder.decode()
    results_beam = [[res[1] for res in beam_results]
                    for beam_results in batch_beam_results]
    # The first beam entry is taken as the best hypothesis for each item.
    results_best = [transcripts[0] for transcripts in results_beam]
    return results_best, results_beam

def remove_chunk_decoder(self, chunk_decoder):
    """Drop this method's local reference to ``chunk_decoder``.

    NOTE(review): ``del`` here only unbinds the local parameter name;
    references held by the caller are unaffected.
    """
    del chunk_decoder
    return None
Loading

0 comments on commit 4756c7d

Please sign in to comment.