From a984b93a51a33c70086d29b93aa1be0059e9de7e Mon Sep 17 00:00:00 2001
From: Xinyuan Li
Date: Tue, 16 Jan 2024 16:59:52 -0500
Subject: [PATCH] Small fix

Signed-off-by: Xinyuan Li
---
 art/estimators/speech_recognition/pytorch_icefall.py      | 9 ++++-----
 .../speech_recognition/test_pytorch_icefall.py            | 7 +++----
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/art/estimators/speech_recognition/pytorch_icefall.py b/art/estimators/speech_recognition/pytorch_icefall.py
index a3a9c7429a..0b9a246c67 100644
--- a/art/estimators/speech_recognition/pytorch_icefall.py
+++ b/art/estimators/speech_recognition/pytorch_icefall.py
@@ -156,17 +156,16 @@ def predict(self, x: np.ndarray, batch_size: int = 1, **kwargs) -> np.ndarray:
         num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
 
         for sample_index in range(num_batch):
-            wav = x_preprocessed[sample_index]  # np.array, len = wav len
-            shape = wav.shape
+            wav = x_preprocessed[sample_index]  # np.array, len = wav len
 
             # extract features
-            x = self.transform_model_input(x=torch.tensor(wav))
+            x, _, _ = self.transform_model_input(x=torch.tensor(wav))
+            shape = torch.tensor([x.shape[1]])
 
-            print(shape)
             encoder_out, encoder_out_lens = self.transducer_model.encoder(x=x, x_lens=shape)
             hyp = greedy_search(model=self.transducer_model, encoder_out=encoder_out, id2word=self.get_id2word)
             decoded_output.append(hyp)
-
+
         return np.concatenate(decoded_output)
 
     def loss_gradient(self, x, y: np.ndarray, **kwargs) -> np.ndarray:
diff --git a/tests/estimators/speech_recognition/test_pytorch_icefall.py b/tests/estimators/speech_recognition/test_pytorch_icefall.py
index a80e7855df..65e7d5ddaf 100644
--- a/tests/estimators/speech_recognition/test_pytorch_icefall.py
+++ b/tests/estimators/speech_recognition/test_pytorch_icefall.py
@@ -88,10 +88,9 @@ def test_pytorch_icefall(art_warning, expected_values, device_type):
 
         # Test transcription outputs
         hyps = []
-        print(xs[0])
-        hyps.append(speech_recognizer.predict(np.array(xs[0])))
-        hyps.append(speech_recognizer.predict(np.array(xs[1])))
-        hyps.append(speech_recognizer.predict(np.array(xs[2])))
+        hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[0]), 0)))
+        hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[1]), 0)))
+        hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[2]), 0)))
         print(hyps)
 
         assert (np.array(hyps) == y).all()