diff --git a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py index 01c5ea942cf1..fcf5d5cebfeb 100644 --- a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py +++ b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py @@ -154,7 +154,7 @@ def launch_spec_augment_kernel( if time_masks > 0 or freq_masks > 0: # Parallelize over freq and time axis, parallel threads over batch # Sequential over masks (adaptive in time). - blocks_per_grid = [sh[1], sh[2]] + blocks_per_grid = tuple([sh[1], sh[2]]) # threads_per_block = min(MAX_THREAD_BUFFER, max(freq_masks, time_masks)) threads_per_block = min(MAX_THREAD_BUFFER, x.shape[0]) diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py index 6e1a8cb247d6..c11ea34f5ce9 100644 --- a/nemo/core/utils/numba_utils.py +++ b/nemo/core/utils/numba_utils.py @@ -17,6 +17,8 @@ import operator import os +import numba + from nemo.utils import model_utils # Prevent Numba CUDA logs from showing at info level @@ -159,4 +161,6 @@ def skip_numba_cuda_test_if_unsupported(min_version: str): if not numba_cuda_support: import pytest - pytest.skip(f"Numba cuda test is being skipped. Minimum version required : {min_version}") + pytest.skip( + f"Numba cuda test is being skipped. Minimum version required : {min_version}, found {numba.version_info}" + )