Skip to content

Commit

Permalink
Small fix
Browse files Browse the repository at this point in the history
Signed-off-by: Xinyuan Li <xli257@b17.clsp.jhu.edu>
  • Loading branch information
Xinyuan Li committed Jan 16, 2024
1 parent 7d7cc2e commit a984b93
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
9 changes: 4 additions & 5 deletions art/estimators/speech_recognition/pytorch_icefall.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,16 @@ def predict(self, x: np.ndarray, batch_size: int = 1, **kwargs) -> np.ndarray:
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))

for sample_index in range(num_batch):
wav = x_preprocessed[sample_index] # np.array, len = wav len
shape = wav.shape
wav = x_preprocessed[sample_index] # np.array, len = wav len

# extract features
x = self.transform_model_input(x=torch.tensor(wav))
x, _, _ = self.transform_model_input(x=torch.tensor(wav))
shape = torch.tensor([x.shape[1]])

print(shape)
encoder_out, encoder_out_lens = self.transducer_model.encoder(x=x, x_lens=shape)
hyp = greedy_search(model=self.transducer_model, encoder_out=encoder_out, id2word=self.get_id2word)
decoded_output.append(hyp)

return np.concatenate(decoded_output)

def loss_gradient(self, x, y: np.ndarray, **kwargs) -> np.ndarray:
Expand Down
7 changes: 3 additions & 4 deletions tests/estimators/speech_recognition/test_pytorch_icefall.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ def test_pytorch_icefall(art_warning, expected_values, device_type):

# Test transcription outputs
hyps = []
print(xs[0])
hyps.append(speech_recognizer.predict(np.array(xs[0])))
hyps.append(speech_recognizer.predict(np.array(xs[1])))
hyps.append(speech_recognizer.predict(np.array(xs[2])))
hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[0]), 0)))
hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[1]), 0)))
hyps.append(speech_recognizer.predict(np.expand_dims(np.array(xs[2]), 0)))
print(hyps)

assert (np.array(hyps) == y).all()
Expand Down

0 comments on commit a984b93

Please sign in to comment.