diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 0f70631f140f..42a0b108944d 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -576,7 +576,7 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op f"Could not check version compatibility." ) return False, msg - except (ImportError, ModuleNotFoundError, AttributeError): + except (ImportError, ModuleNotFoundError): pass msg = f"Lib {lib_name} has not been installed. Please use pip or conda to install this package." diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py index aa96b6753849..727742fdffb5 100644 --- a/tests/collections/nlp/test_flash_attention.py +++ b/tests/collections/nlp/test_flash_attention.py @@ -44,23 +44,16 @@ except (ImportError, ModuleNotFoundError): HAVE_TRITON = False -try: - import pynvml - - HAVE_PYNVML = True -except (ImportError, ModuleNotFoundError): - HAVE_PYNVML = False +import pynvml def HAVE_AMPERE_GPU(): - if HAVE_PYNVML: - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - device_arch = pynvml.nvmlDeviceGetArchitecture(handle) - pynvml.nvmlShutdown() - return device_arch == pynvml.NVML_DEVICE_ARCH_AMPERE - else: - return False + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + device_arch = pynvml.nvmlDeviceGetArchitecture(handle) + pynvml.nvmlShutdown() + return device_arch == pynvml.NVML_DEVICE_ARCH_AMPERE + @pytest.mark.run_only_on('GPU') @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed")