Skip to content

Commit

Permalink
[Enhancement]Revise collect_env for win platform (#112)
Browse files Browse the repository at this point in the history
* Revise collect_env for win platform

* add ut

* fix ut

* remove win ut
  • Loading branch information
MeowZheng authored Mar 29, 2022
1 parent 34fbe52 commit a8ee90e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
9 changes: 6 additions & 3 deletions mmflow/utils/collect_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def collect_env() -> dict:
for name, devids in devices.items():
env_info['GPU ' + ','.join(devids)] = name

gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
gcc = gcc.decode('utf-8').strip()
env_info['GCC'] = gcc
try:
gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
gcc = gcc.decode('utf-8').strip()
env_info['GCC'] = gcc
except subprocess.CalledProcessError: # gcc is unavailable
env_info['GCC'] = 'n/a'

env_info['PyTorch'] = torch.__version__
env_info['PyTorch compiling details'] = get_build_config()
Expand Down
29 changes: 29 additions & 0 deletions tests/test_utils/test_set_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import multiprocessing as mp
import os
import platform
import sys

import cv2
import mmcv
import pytest
from mmcv import Config

import mmflow
from mmflow.utils import setup_multi_processes


Expand Down Expand Up @@ -85,3 +88,29 @@ def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
assert cv2.getNumThreads() == sys_cv_threads
assert 'OMP_NUM_THREADS' not in os.environ
assert 'MKL_NUM_THREADS' not in os.environ


def test_collect_env():
try:
import torch # noqa: F401
except ModuleNotFoundError:
pytest.skip('skipping tests that require PyTorch')

from mmflow.utils import collect_env
env_info = collect_env()
expected_keys = [
'sys.platform', 'Python', 'CUDA available', 'PyTorch',
'PyTorch compiling details', 'OpenCV', 'MMCV', 'MMCV Compiler',
'MMCV CUDA Compiler', 'MMFlow', 'GCC'
]
for key in expected_keys:
assert key in env_info

if env_info['CUDA available']:
for key in ['CUDA_HOME', 'NVCC']:
assert key in env_info

assert env_info['sys.platform'] == sys.platform
assert env_info['Python'] == sys.version.replace('\n', '')
assert env_info['MMCV'] == mmcv.__version__
assert mmflow.__version__ in env_info['MMFlow']

0 comments on commit a8ee90e

Please sign in to comment.