Skip to content

Commit

Permalink
Fix Unhashable type list for Numba Cuda spec augment kernel (NVIDIA#5093
Browse files Browse the repository at this point in the history
) (NVIDIA#5094)

Signed-off-by: smajumdar <smajumdar@nvidia.com>

Signed-off-by: smajumdar <smajumdar@nvidia.com>

Signed-off-by: smajumdar <smajumdar@nvidia.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Signed-off-by: 1-800-bad-code <shane.carroll@utsa.edu>
  • Loading branch information
2 people authored and 1-800-BAD-CODE committed Nov 13, 2022
1 parent 112d37e commit 139517f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def launch_spec_augment_kernel(
if time_masks > 0 or freq_masks > 0:
# Parallelize over freq and time axis, parallel threads over batch
# Sequential over masks (adaptive in time).
blocks_per_grid = [sh[1], sh[2]]
blocks_per_grid = tuple([sh[1], sh[2]])
# threads_per_block = min(MAX_THREAD_BUFFER, max(freq_masks, time_masks))
threads_per_block = min(MAX_THREAD_BUFFER, x.shape[0])

Expand Down
6 changes: 5 additions & 1 deletion nemo/core/utils/numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import operator
import os

import numba

from nemo.utils import model_utils

# Prevent Numba CUDA logs from showing at info level
Expand Down Expand Up @@ -159,4 +161,6 @@ def skip_numba_cuda_test_if_unsupported(min_version: str):
if not numba_cuda_support:
import pytest

pytest.skip(f"Numba cuda test is being skipped. Minimum version required : {min_version}")
pytest.skip(
f"Numba cuda test is being skipped. Minimum version required : {min_version}, found {numba.version_info}"
)

0 comments on commit 139517f

Please sign in to comment.