fix a bug
Jackwaterveg committed Jan 24, 2022
1 parent eb4edad commit 1924153
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions paddlespeech/s2t/exps/deepspeech2/model.py
@@ -413,19 +413,19 @@ def test(self):
 
     def compute_result_transcripts(self, audio, audio_len):
         if self.args.model_type == "online":
-            output_probs, output_lens, trans = self.static_forward_online(
+            output_probs, output_lens, trans_batch = self.static_forward_online(
                 audio, audio_len, decoder_chunk_size=1)
-            result_transcripts = trans[-1:]
+            result_transcripts = [trans[-1] for trans in trans_batch]
         elif self.args.model_type == "offline":
-            batch_size = output_probs.shape[0]
-            self.model.decoder.reset_decoder(batch_size = batch_size)
             output_probs, output_lens = self.static_forward_offline(audio,
                                                                      audio_len)
+            batch_size = output_probs.shape[0]
+            self.model.decoder.reset_decoder(batch_size=batch_size)
 
             self.model.decoder.next(output_probs, output_lens)
 
             trans_best, trans_beam = self.model.decoder.decode()
 
             result_transcripts = trans_best
 
         else:
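For context on the offline-branch reordering in this hunk, here is a minimal standalone sketch (the forward_offline stand-in and the dummy shapes are invented, not from the repository) of why the old order failed: output_probs is assigned later in the same function, so reading its shape before the forward call raises UnboundLocalError, and the commit simply moves the batch_size/reset_decoder lines after the forward pass.

import numpy as np

def forward_offline(audio, audio_len):
    # hypothetical stand-in for static_forward_offline: (batch, time, vocab) probabilities
    return np.zeros((len(audio), 10, 30)), audio_len

def buggy_order(audio, audio_len):
    batch_size = output_probs.shape[0]  # UnboundLocalError: assigned only below
    output_probs, output_lens = forward_offline(audio, audio_len)
    return batch_size

def fixed_order(audio, audio_len):
    output_probs, output_lens = forward_offline(audio, audio_len)
    batch_size = output_probs.shape[0]  # safe: output_probs exists now
    return batch_size

audio = np.zeros((2, 16000), dtype=np.float32)
print(fixed_order(audio, np.array([16000, 16000])))   # 2
# buggy_order(audio, np.array([16000, 16000])) would raise UnboundLocalError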
@@ -485,7 +485,7 @@ def static_forward_online(self, audio, audio_len,
         x_list = np.split(x_batch, batch_size, axis=0)
         x_len_list = np.split(x_len_batch, batch_size, axis=0)
 
-        trans = []
+        trans_batch = []
         for x, x_len in zip(x_list, x_len_list):
             if self.args.enable_auto_log is True:
                 self.autolog.times.start()
@@ -518,14 +518,14 @@ def static_forward_online(self, audio, audio_len,
             h_box_handle = self.predictor.get_input_handle(input_names[2])
             c_box_handle = self.predictor.get_input_handle(input_names[3])
 
-            trans_chunk_list = []
+            trans = []
             probs_chunk_list = []
             probs_chunk_lens_list = []
             if self.args.enable_auto_log is True:
                 # record the model preprocessing time
                 self.autolog.times.stamp()
-            self.model.decoder.reset_decoder(batch_size = 1)
-
+
+            self.model.decoder.reset_decoder(batch_size=1)
             for i in range(0, num_chunk):
                 start = i * chunk_stride
                 end = start + chunk_size
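The renames in this hunk and the next serve the per-utterance streaming loop: the decoder is reset once per utterance with batch_size=1, each chunk's probabilities are pushed with next(), and decode() after every chunk returns the current best hypothesis, which is appended to trans. A rough sketch with a mock decoder (its internal behavior is invented; only the call pattern mirrors this file):

class MockDecoder:
    # invented stand-in; only reset_decoder/next/decode mirror the calls in this file
    def __init__(self):
        self.partial = ""

    def reset_decoder(self, batch_size=1):
        self.partial = ""

    def next(self, probs_chunk, probs_chunk_lens):
        self.partial += "x"          # the real decoder advances its beam-search state here

    def decode(self):
        return [self.partial], [[self.partial]]   # (best, beam) hypotheses per utterance

decoder = MockDecoder()
trans = []
decoder.reset_decoder(batch_size=1)              # once per utterance
for i in range(3):                               # num_chunk chunks of this utterance
    decoder.next(None, None)
    trans_best, trans_beam = decoder.decode()
    trans.append(trans_best[0])                  # running partial after each chunk
print(trans)                                     # ['x', 'xx', 'xxx']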
@@ -569,8 +569,7 @@ def static_forward_online(self, audio, audio_len,
                 probs_chunk_lens_list.append(output_chunk_lens)
                 trans_best, trans_beam = self.model.decoder.decode()
                 trans.append(trans_best[0])
-
-
+            trans_batch.append(trans)
             output_probs = np.concatenate(probs_chunk_list, axis=1)
             output_lens = np.sum(probs_chunk_lens_list, axis=0)
             vocab_size = output_probs.shape[2]
@@ -592,7 +591,7 @@ def static_forward_online(self, audio, audio_len,
                 self.autolog.times.end()
         output_probs = np.concatenate(output_probs_list, axis=0)
         output_lens = np.concatenate(output_lens_list, axis=0)
-        return output_probs, output_lens, trans
+        return output_probs, output_lens, trans_batch
 
     def static_forward_offline(self, audio, audio_len):
         """