From a8ee90ee5d14a7537773f536e3404094a821d5a8 Mon Sep 17 00:00:00 2001 From: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Date: Tue, 29 Mar 2022 10:48:32 +0800 Subject: [PATCH] [Enhancement]Revise collect_env for win platform (#112) * Revise collect_env for win platform * add ut * fix ut * remove win ut --- mmflow/utils/collect_env.py | 9 ++++++--- tests/test_utils/test_set_env.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/mmflow/utils/collect_env.py b/mmflow/utils/collect_env.py index 6c967723..7971b1f3 100644 --- a/mmflow/utils/collect_env.py +++ b/mmflow/utils/collect_env.py @@ -48,9 +48,12 @@ def collect_env() -> dict: for name, devids in devices.items(): env_info['GPU ' + ','.join(devids)] = name - gcc = subprocess.check_output('gcc --version | head -n1', shell=True) - gcc = gcc.decode('utf-8').strip() - env_info['GCC'] = gcc + try: + gcc = subprocess.check_output('gcc --version | head -n1', shell=True) + gcc = gcc.decode('utf-8').strip() + env_info['GCC'] = gcc + except subprocess.CalledProcessError: # gcc is unavailable + env_info['GCC'] = 'n/a' env_info['PyTorch'] = torch.__version__ env_info['PyTorch compiling details'] = get_build_config() diff --git a/tests/test_utils/test_set_env.py b/tests/test_utils/test_set_env.py index fb77234e..57d34177 100644 --- a/tests/test_utils/test_set_env.py +++ b/tests/test_utils/test_set_env.py @@ -2,11 +2,14 @@ import multiprocessing as mp import os import platform +import sys import cv2 +import mmcv import pytest from mmcv import Config +import mmflow from mmflow.utils import setup_multi_processes @@ -85,3 +88,29 @@ def test_setup_multi_processes(workers_per_gpu, valid, env_cfg): assert cv2.getNumThreads() == sys_cv_threads assert 'OMP_NUM_THREADS' not in os.environ assert 'MKL_NUM_THREADS' not in os.environ + + +def test_collect_env(): + try: + import torch # noqa: F401 + except ModuleNotFoundError: + pytest.skip('skipping tests that require PyTorch') + + from mmflow.utils import collect_env + env_info = collect_env() + expected_keys = [ + 'sys.platform', 'Python', 'CUDA available', 'PyTorch', + 'PyTorch compiling details', 'OpenCV', 'MMCV', 'MMCV Compiler', + 'MMCV CUDA Compiler', 'MMFlow', 'GCC' + ] + for key in expected_keys: + assert key in env_info + + if env_info['CUDA available']: + for key in ['CUDA_HOME', 'NVCC']: + assert key in env_info + + assert env_info['sys.platform'] == sys.platform + assert env_info['Python'] == sys.version.replace('\n', '') + assert env_info['MMCV'] == mmcv.__version__ + assert mmflow.__version__ in env_info['MMFlow']