diff --git a/tests/estimators/speech_recognition/test_pytorch_icefall.py b/tests/estimators/speech_recognition/test_pytorch_icefall.py index 65e7d5ddaf..ded9308afd 100644 --- a/tests/estimators/speech_recognition/test_pytorch_icefall.py +++ b/tests/estimators/speech_recognition/test_pytorch_icefall.py @@ -46,6 +46,7 @@ def test_pytorch_icefall(art_warning, expected_values, device_type): # load_model_ensemble transducer_model = get_transducer_model(params) + transducer_model.device = 'cpu' word2ids = get_word2id(params) get_id2word = get_id2word(params) model_ensemble = {