From 9bac71d1776d015edd1841649132187a1dcac295 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 5 Oct 2022 18:28:44 -0700 Subject: [PATCH] Fix Unhashable type list for Numba Cuda spec augment kernel (#5093) (#5094) Signed-off-by: smajumdar Signed-off-by: smajumdar Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Hainan Xu --- .../asr/parts/numba/spec_augment/spec_aug_numba.py | 2 +- nemo/core/utils/numba_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py index 01c5ea942cf1..fcf5d5cebfeb 100644 --- a/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py +++ b/nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py @@ -154,7 +154,7 @@ def launch_spec_augment_kernel( if time_masks > 0 or freq_masks > 0: # Parallelize over freq and time axis, parallel threads over batch # Sequential over masks (adaptive in time). - blocks_per_grid = [sh[1], sh[2]] + blocks_per_grid = tuple([sh[1], sh[2]]) # threads_per_block = min(MAX_THREAD_BUFFER, max(freq_masks, time_masks)) threads_per_block = min(MAX_THREAD_BUFFER, x.shape[0]) diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py index 6e1a8cb247d6..c11ea34f5ce9 100644 --- a/nemo/core/utils/numba_utils.py +++ b/nemo/core/utils/numba_utils.py @@ -17,6 +17,8 @@ import operator import os +import numba + from nemo.utils import model_utils # Prevent Numba CUDA logs from showing at info level @@ -159,4 +161,6 @@ def skip_numba_cuda_test_if_unsupported(min_version: str): if not numba_cuda_support: import pytest - pytest.skip(f"Numba cuda test is being skipped. Minimum version required : {min_version}") + pytest.skip( + f"Numba cuda test is being skipped. Minimum version required : {min_version}, found {numba.version_info}" + )