From 1924153660635b581d70b5961313413e93615786 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 24 Jan 2022 04:04:33 +0000 Subject: [PATCH] fix a bug --- paddlespeech/s2t/exps/deepspeech2/model.py | 23 +++++++++++----------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index cc4dcdf1382..f43511baadb 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -413,19 +413,19 @@ def test(self): def compute_result_transcripts(self, audio, audio_len): if self.args.model_type == "online": - output_probs, output_lens, trans = self.static_forward_online( + output_probs, output_lens, trans_batch = self.static_forward_online( audio, audio_len, decoder_chunk_size=1) - result_transcripts = trans[-1:] + result_transcripts = [trans[-1] for trans in trans_batch] elif self.args.model_type == "offline": - batch_size = output_probs.shape[0] - self.model.decoder.reset_decoder(batch_size = batch_size) output_probs, output_lens = self.static_forward_offline(audio, audio_len) + batch_size = output_probs.shape[0] + self.model.decoder.reset_decoder(batch_size=batch_size) self.model.decoder.next(output_probs, output_lens) trans_best, trans_beam = self.model.decoder.decode() - + result_transcripts = trans_best else: @@ -485,7 +485,7 @@ def static_forward_online(self, audio, audio_len, x_list = np.split(x_batch, batch_size, axis=0) x_len_list = np.split(x_len_batch, batch_size, axis=0) - trans = [] + trans_batch = [] for x, x_len in zip(x_list, x_len_list): if self.args.enable_auto_log is True: self.autolog.times.start() @@ -518,14 +518,14 @@ def static_forward_online(self, audio, audio_len, h_box_handle = self.predictor.get_input_handle(input_names[2]) c_box_handle = self.predictor.get_input_handle(input_names[3]) - trans_chunk_list = [] + trans = [] probs_chunk_list = [] probs_chunk_lens_list = [] if self.args.enable_auto_log is True: # record the model preprocessing time self.autolog.times.stamp() - - self.model.decoder.reset_decoder(batch_size = 1) + + self.model.decoder.reset_decoder(batch_size=1) for i in range(0, num_chunk): start = i * chunk_stride end = start + chunk_size @@ -569,8 +569,7 @@ def static_forward_online(self, audio, audio_len, probs_chunk_lens_list.append(output_chunk_lens) trans_best, trans_beam = self.model.decoder.decode() trans.append(trans_best[0]) - - + trans_batch.append(trans) output_probs = np.concatenate(probs_chunk_list, axis=1) output_lens = np.sum(probs_chunk_lens_list, axis=0) vocab_size = output_probs.shape[2] @@ -592,7 +591,7 @@ def static_forward_online(self, audio, audio_len, self.autolog.times.end() output_probs = np.concatenate(output_probs_list, axis=0) output_lens = np.concatenate(output_lens_list, axis=0) - return output_probs, output_lens, trans + return output_probs, output_lens, trans_batch def static_forward_offline(self, audio, audio_len): """