diff --git a/.travis.yml b/.travis.yml index 6d7aa11c0..b05d7dde4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -50,7 +50,7 @@ install: before_script: - export CAFFE_ROOT=~/caffe - - export TORCH_ROOT=~/torch/install + - export TORCH_ROOT=~/torch # Disable OpenMP multi-threading - export OMP_NUM_THREADS=1 diff --git a/digits-devserver b/digits-devserver index d48969468..e41e85832 100755 --- a/digits-devserver +++ b/digits-devserver @@ -5,6 +5,9 @@ import argparse import sys import digits +import digits.config +import digits.log +from digits.webapp import app, socketio, scheduler if __name__ == '__main__': parser = argparse.ArgumentParser(description='Run the DIGITS development server') @@ -13,10 +16,6 @@ if __name__ == '__main__': default=5000, help='Port to run app on (default 5000)' ) - parser.add_argument('-c', '--config', - action='store_true', - help='Edit the application configuration' - ) parser.add_argument('-d', '--debug', action='store_true', help='Run the application in debug mode (reloads when the source changes and gives more detailed error messages)' @@ -28,19 +27,10 @@ if __name__ == '__main__': args = vars(parser.parse_args()) - from digits import config - if args['version']: print digits.__version__ sys.exit() - if args['config']: - config.load_config('normal') - else: - config.load_config('quiet') - - from digits.webapp import app, socketio, scheduler - print ' ___ ___ ___ ___ _____ ___' print ' | \_ _/ __|_ _|_ _/ __|' print ' | |) | | (_ || | | | \__ \\' diff --git a/digits/config/__init__.py b/digits/config/__init__.py index f5eb3d34e..f7e5cdbf6 100644 --- a/digits/config/__init__.py +++ b/digits/config/__init__.py @@ -1,14 +1,21 @@ # Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import -import os +# Create this object before importing the following imports, since they edit the list +option_list = {} -# These are the only two functions that the rest of DIGITS needs to use -from .current_config import config_value -from .load import load_config +from . import caffe +from . import gpu_list +from . import jobs_dir +from . import log_file +from . import torch +from . import server_name +from . import extension_list # Import this last, since it imports other things inside DIGITS -if 'DIGITS_MODE_TEST' in os.environ: - # load the config automatically during testing - # it's hard to do it manually with nosetests - load_config() + +def config_value(option): + """ + Return the current configuration value for the given option + """ + return option_list[option] diff --git a/digits/config/caffe.py b/digits/config/caffe.py new file mode 100644 index 000000000..eeeb5b319 --- /dev/null +++ b/digits/config/caffe.py @@ -0,0 +1,236 @@ +# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +import imp +import os +import platform +import re +import subprocess +import sys + +from . import option_list +from digits import device_query +from digits.utils import parse_version +from digits.utils.errors import UnsupportedPlatformError + + +def load_from_envvar(envvar): + """ + Load information from an installation indicated by an environment variable + """ + value = os.environ[envvar].strip().strip("\"' ") + + executable_dir = os.path.join(value, 'build', 'tools') + python_dir = os.path.join(value, 'python') + + try: + executable = find_executable_in_dir(executable_dir) + if executable is None: + raise ValueError('Caffe executable not found at "%s"' + % executable_dir) + if not is_pycaffe_in_dir(python_dir): + raise ValueError('Pycaffe not found in "%s"' + % python_dir) + import_pycaffe(python_dir) + version, flavor = get_version_and_flavor(executable) + except: + print ('"%s" from %s does not point to a valid installation of Caffe.' + % (value, envvar)) + print 'Use the envvar CAFFE_ROOT to indicate a valid installation.' + raise + return executable, version, flavor + + +def load_from_path(): + """ + Load information from an installation on standard paths (PATH and PYTHONPATH) + """ + try: + executable = find_executable_in_dir() + if executable is None: + raise ValueError('Caffe executable not found in PATH') + if not is_pycaffe_in_dir(): + raise ValueError('Pycaffe not found in PYTHONPATH') + import_pycaffe() + version, flavor = get_version_and_flavor(executable) + except: + print 'A valid Caffe installation was not found on your system.' + print 'Use the envvar CAFFE_ROOT to indicate a valid installation.' + raise + return executable, version, flavor + + +def find_executable_in_dir(dirname=None): + """ + Returns the path to the caffe executable at dirname + If dirname is None, search all directories in sys.path + Returns None if not found + """ + if platform.system() == 'windows': + exe_name = 'caffe.exe' + else: + exe_name = 'caffe' + + if dirname is None: + dirnames = [path.strip("\"' ") for path in os.environ['PATH'].split(os.pathsep)] + else: + dirnames = [dirname] + + for dirname in dirnames: + path = os.path.join(dirname, exe_name) + if os.path.isfile(path) and os.access(path, os.X_OK): + return path + return None + + +def is_pycaffe_in_dir(dirname=None): + """ + Returns True if you can "import caffe" from dirname + If dirname is None, search all directories in sys.path + """ + old_path = sys.path + if dirname is not None: + sys.path = [dirname] # temporarily replace sys.path + try: + imp.find_module('caffe') + except ImportError as e: + return False + finally: + sys.path = old_path + return True + + +def import_pycaffe(dirname=None): + """ + Imports caffe + If dirname is not None, prepend it to sys.path first + """ + if dirname is not None: + sys.path.insert(0, dirname) + # Add to PYTHONPATH so that build/tools/caffe is aware of python layers there + os.environ['PYTHONPATH'] = '%s%s%s' % ( + dirname, os.pathsep, os.environ.get('PYTHONPATH')) + + # Suppress GLOG output for python bindings + GLOG_minloglevel = os.environ.pop('GLOG_minloglevel', None) + # Show only "ERROR" and "FATAL" + os.environ['GLOG_minloglevel'] = '2' + + # for Windows environment, loading h5py before caffe solves the issue mentioned in + # https://github.com/NVIDIA/DIGITS/issues/47#issuecomment-206292824 + import h5py + try: + import caffe + except ImportError: + print 'Did you forget to "make pycaffe"?' + raise + + # Strange issue with protocol buffers and pickle - see issue #32 + sys.path.insert(0, os.path.join( + os.path.dirname(caffe.__file__), 'proto')) + + # Turn GLOG output back on for subprocess calls + if GLOG_minloglevel is None: + del os.environ['GLOG_minloglevel'] + else: + os.environ['GLOG_minloglevel'] = GLOG_minloglevel + + +def get_version_and_flavor(executable): + """ + Returns (version, flavor) + Should be called after import_pycaffe() + """ + version_string = get_version_from_pycaffe() + if version_string is None: + version_string = get_version_from_cmdline(executable) + if version_string is None: + version_string = get_version_from_soname(executable) + + if version_string is None: + raise ValueError('Could not find version information for Caffe build ' + + 'at "%s". Upgrade your installation' % executable) + + version = parse_version(version_string) + + if parse_version(0,99,0) > version > parse_version(0,9,0): + flavor = 'NVIDIA' + minimum_version = '0.11.0' + if version < parse_version(minimum_version): + raise ValueError( + 'Required version "%s" is greater than "%s". Upgrade your installation.' + % (nvidia_minimum_version, version_string)) + else: + flavor = 'BVLC' + + return version_string, flavor + + +def get_version_from_pycaffe(): + import caffe + try: + from caffe import __version__ as version + return version + except AttributeError: + return None + + +def get_version_from_cmdline(executable): + command = [executable, '-version'] + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.wait(): + print p.stderr.read().strip() + raise RuntimeError('"%s" returned error code %s' % (command, p.returncode)) + + for line in p.stdout: + if 'version' in line: + return line[line.find(pattern) + len(pattern)+1:].strip() + return None + + +def get_version_from_soname(executable): + command = ['ldd', executable] + p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.wait(): + print p.stderr.read().strip() + raise RuntimeError('"%s" returned error code %s' % (command, p.returncode)) + + # Search output for caffe library + caffe_line = None + for line in p.stdout: + if 'libcaffe' in line: + caffe_line = line + break + + if caffe_line is None: + raise ValueError('libcaffe not found in linked libraries for "%s"' + % executable) + + # Read the symlink for libcaffe from ldd output + symlink = caffe_line.split()[2] + filename = os.path.basename(os.path.realpath(symlink)) + + # parse the version string + match = re.match(r'%s(.*)\.so\.(\S+)$' % (libname), filename) + if match: + return match.group(2) + else: + return None + + + +if 'CAFFE_ROOT' in os.environ: + executable, version, flavor = load_from_envvar('CAFFE_ROOT') +elif 'CAFFE_HOME' in os.environ: + executable, version, flavor = load_from_envvar('CAFFE_HOME') +else: + executable, version, flavor = load_from_path() + +option_list['caffe'] = { + 'executable': executable, + 'version': version, + 'flavor': flavor, + 'multi_gpu': (flavor == 'BVLC' or parse_version(version) >= parse_version(0,12)), + 'cuda_enabled': (len(device_query.get_devices()) > 0), +} + diff --git a/digits/config/caffe_option.py b/digits/config/caffe_option.py deleted file mode 100644 index f222c8d64..000000000 --- a/digits/config/caffe_option.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import imp -import os -import platform -import re -import subprocess -import sys - -from . import config_option -from . import prompt -from digits import device_query -from digits.utils import parse_version -from digits.utils.errors import UnsupportedPlatformError - -class CaffeOption(config_option.FrameworkOption): - - @staticmethod - def config_file_key(): - return 'caffe_root' - - @classmethod - def prompt_title(cls): - return 'Caffe' - - @classmethod - def prompt_message(cls): - return 'Where is caffe installed?' - - def optional(self): - #TODO: make this optional - return False - - def suggestions(self): - suggestions = [] - if 'CAFFE_ROOT' in os.environ: - d = os.environ['CAFFE_ROOT'] - try: - suggestions.append(prompt.Suggestion( - self.validate(d), 'R', - desc='CAFFE_ROOT', default=True)) - except config_option.BadValue: - pass - if 'CAFFE_HOME' in os.environ: - d = os.environ['CAFFE_HOME'] - try: - default = True - if len(suggestions) > 0: - default = False - suggestions.append(prompt.Suggestion( - self.validate(d), 'H', - desc='CAFFE_HOME', default=default)) - except config_option.BadValue: - pass - suggestions.append(prompt.Suggestion('', 'P', - desc='PATH/PYTHONPATH', default=True)) - return suggestions - - @staticmethod - def is_path(): - return True - - @classmethod - def validate(cls, value): - if not value: - return value - - if value == '': - # Find the executable - executable = cls.find_executable('caffe') - if not executable: - executable = cls.find_executable('caffe.exe') - if not executable: - raise config_option.BadValue('caffe binary not found in PATH') - cls.validate_version(executable) - - # Find the python module - try: - imp.find_module('caffe') - except ImportError: - raise config_option.BadValue('caffe python package not found in PYTHONPATH') - return value - else: - # Find the executable - value = os.path.abspath(value) - if not os.path.isdir(value): - raise config_option.BadValue('"%s" is not a directory' % value) - expected_path = os.path.join(value, 'build', 'tools', 'caffe') - if not os.path.exists(expected_path): - raise config_option.BadValue('caffe binary not found at "%s"' % value) - cls.validate_version(expected_path) - - # Find the python module - pythonpath = os.path.join(value, 'python') - sys.path.insert(0, pythonpath) - try: - imp.find_module('caffe') - except ImportError as e: - raise config_option.BadValue('Error while importing caffe from "%s": %s' % ( - pythonpath, e.message)) - finally: - # Don't actually add this until apply() is called - sys.path.pop(0) - - return value - - @staticmethod - def find_executable(program): - """ - Finds an executable by searching through PATH - Returns the path to the executable or None - """ - for path in os.environ['PATH'].split(os.pathsep): - path = path.strip('"') - executable = os.path.join(path, program) - if os.path.isfile(executable) and os.access(executable, os.X_OK): - return executable - return None - - @classmethod - def validate_version(cls, executable): - """ - Utility for checking the caffe version from within validate() - Throws BadValue - - Arguments: - executable -- path to a caffe executable - """ - nvidia_minimum_version = '0.11.0' - info_dict = cls.get_info(executable) - if info_dict['ver_str'] is None: - raise config_option.BadValue('Your Caffe does not have version info. Please upgrade it.') - else: - flavor = CaffeOption.get_flavor(info_dict['ver_str']) - if flavor == 'NVIDIA' and parse_version(nvidia_minimum_version) > parse_version(info_dict['ver_str']): - raise config_option.BadValue( - 'Required version "{min_ver}" is greater than "{running_ver}". '\ - 'Upgrade your installation.'\ - .format(min_ver = nvidia_minimum_version, running_ver = info_dict['ver_str'])) - else: - return True - - @staticmethod - def get_executable_version_string(executable): - """ - Returns the caffe version as either a string from results of command line option '-version' - or None if '-version' not implemented - - Arguments: - executable -- path to a caffe executable - """ - - supported_platforms = ['Windows', 'Linux', 'Darwin'] - version_string = None - if platform.system() in supported_platforms: - p = subprocess.Popen([executable, '-version'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if p.wait(): - raise config_option.BadValue(p.stderr.read().strip()) - else: - pattern = 'version' - for line in p.stdout: - if pattern in line: - version_string = line[line.find(pattern) + len(pattern)+1:].rstrip() - break - try: - parse_version(version_string) - except ValueError: #version_string is either ill-formatted or 'CAFFE_VERSION' - version_string = None - return version_string - else: - raise UnsupportedPlatformError('platform "%s" not supported' % platform.system()) - - @staticmethod - def get_linked_library_version_string(executable): - """ - Returns the information about executable's linked library name version - or None if error - - Arguments: - executable -- path to a caffe executable - """ - - version_string = None - - p = subprocess.Popen(['ldd', executable], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - if p.wait(): - raise config_option.BadValue(p.stderr.read().strip()) - libname = 'libcaffe' - caffe_line = None - - # Search output for caffe library - for line in p.stdout: - if libname in line: - caffe_line = line - break - if caffe_line is None: - raise config_option.BadValue('%s not found in ldd output' % libname) - - # Read the symlink for libcaffe from ldd output - symlink = caffe_line.split()[2] - filename = os.path.basename(os.path.realpath(symlink)) - - # parse the version string - match = re.match(r'%s(.*)\.so\.(\S+)$' - % (libname), filename) - if match: - version_string = match.group(2) - return version_string - - @staticmethod - def get_flavor(ver_str): - """ - Returns the information about caffe library enhancement (NVIDIA or BVLC) - - Arguments: - ver_str -- version string that can identify enhancement flavor - """ - if parse_version(0,99,0) > parse_version(ver_str) > parse_version(0,9,0): - return 'NVIDIA' - else: - return 'BVLC' - - @staticmethod - def get_version_string(executable): - """ - Returns the caffe version as a string from executable or linked library name - - Arguments: - executable -- path to a caffe executable - """ - - version_string = CaffeOption.get_executable_version_string(executable) - if not version_string and platform.system() == 'Linux': - version_string = CaffeOption.get_linked_library_version_string(executable) - return version_string - - @staticmethod - def get_info(executable): - """ - Returns the caffe info a dict {'ver_str', 'flavor'} - values of dict are None if unable to get version. - - Arguments: - executable -- path to a caffe executable - """ - try: - version_string = CaffeOption.get_version_string(executable) - except UnsupportedPlatformError: - return {'ver_str': None, 'flavor': None} - - if version_string: - return {'ver_str': version_string, 'flavor': CaffeOption.get_flavor(version_string)} - else: - return {'ver_str': None, 'flavor': None} - - def _set_config_dict_value(self, value): - if not value: - self._config_dict_value = None - else: - if value == '': - executable = self.find_executable('caffe') - if not executable: - executable = self.find_executable('caffe.exe') - else: - executable = os.path.join(value, 'build', 'tools', 'caffe') - - info_dict = self.get_info(executable) - version = parse_version(info_dict['ver_str']) - if version >= parse_version(0,12): - multi_gpu = True - else: - multi_gpu = False - - flavor = info_dict['flavor'] - # TODO: ask caffe for this information - cuda_enabled = len(device_query.get_devices()) > 0 - - self._config_dict_value = { - 'executable': executable, - 'version': version, - 'ver_str': info_dict['ver_str'], - 'multi_gpu': multi_gpu, - 'cuda_enabled': cuda_enabled, - 'flavor': flavor - } - - def apply(self): - if self._config_file_value: - # Suppress GLOG output for python bindings - GLOG_minloglevel = os.environ.pop('GLOG_minloglevel', None) - # Show only "ERROR" and "FATAL" - os.environ['GLOG_minloglevel'] = '2' - - if self._config_file_value != '': - # Add caffe/python to PATH - p = os.path.join(self._config_file_value, 'python') - sys.path.insert(0, p) - # Add caffe/python to PYTHONPATH - # so that build/tools/caffe is aware of python layers there - os.environ['PYTHONPATH'] = '%s:%s' % (p, os.environ.get('PYTHONPATH')) - - # for Windows environment, loading h5py before caffe solves the issue mentioned in - # https://github.com/NVIDIA/DIGITS/issues/47#issuecomment-206292824 - import h5py - try: - import caffe - except ImportError: - print 'Did you forget to "make pycaffe"?' - raise - - # Strange issue with protocol buffers and pickle - see issue #32 - sys.path.insert(0, os.path.join( - os.path.dirname(caffe.__file__), 'proto')) - - # Turn GLOG output back on for subprocess calls - if GLOG_minloglevel is None: - del os.environ['GLOG_minloglevel'] - else: - os.environ['GLOG_minloglevel'] = GLOG_minloglevel - - diff --git a/digits/config/config_file.py b/digits/config/config_file.py deleted file mode 100644 index aed5d4ce0..000000000 --- a/digits/config/config_file.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -from collections import OrderedDict -import ConfigParser -import os -import platform - -import digits - -class ConfigFile(object): - """ - Handles IO on a config file - """ - config_section = 'DIGITS' - - def __init__(self, filename): - """ - Doesn't make a fuss if the file doesn't exist - Use exists() to check - """ - self._filename = filename - self._options = OrderedDict() - self.load() - self._dirty = False - - def __str__(self): - s = '' - for item in self._options.iteritems(): - s += '%15s = %s\n' % item - return s - - def filename(self): - return self._filename - - def exists(self): - """ - Returns True if the file exists - """ - return self._filename is not None and os.path.isfile(self._filename) - - def can_read(self): - """ - Returns True if the file can be read - """ - return self.exists() and os.access(self._filename, os.R_OK) - - def can_write(self): - """ - Returns True if the file can be written - """ - if os.path.isfile(self._filename): - return os.access(self._filename, os.W_OK) - else: - return os.access( - os.path.dirname(self._filename), - os.W_OK) - - def load(self): - """ - Load options from the file - Overwrites any values in self._options - Returns True if the file loaded successfully - """ - if not self.exists(): - return False - cfg = ConfigParser.SafeConfigParser() - cfg.read(self._filename) - if not cfg.has_section(self.config_section): - raise ValueError('expected section "%s" in config file at "%s"' % ( - self.config_section, self._filename)) - - for key, val in cfg.items(self.config_section): - self._options[key] = val - return True - - def get(self, name): - """ - Get a config option by name - """ - if name in self._options: - return self._options[name] - else: - return None - - def set(self, name, value): - """ - Set a config option by name - """ - if value is None: - if name in self._options: - del self._options[name] - self._dirty = True - else: - if not (name in self._options and self._options[name] == value): - self._dirty = True - self._options[name] = value - - def dirty(self): - """ - Returns True if there are changes to be written to disk - """ - return self._dirty - - def save(self): - """ - Save config file to disk - """ - cfg = ConfigParser.SafeConfigParser() - cfg.add_section(self.config_section) - for name, value in self._options.iteritems(): - cfg.set(self.config_section, name, value) - with open(self._filename, 'w') as outfile: - cfg.write(outfile) - - -class InstanceConfigFile(ConfigFile): - def __init__(self): - filename = os.path.join(os.path.dirname(digits.__file__), 'digits.cfg') - super(InstanceConfigFile, self).__init__(filename) diff --git a/digits/config/config_option.py b/digits/config/config_option.py deleted file mode 100644 index 7022d1a51..000000000 --- a/digits/config/config_option.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. - -class BadValue(Exception): - # thrown when a bad value is passed to option.set() - pass - -class Visibility(object): - NEVER, HIDDEN, DEFAULT = range(3) - -class Option(object): - """ - Base class for configuration options - """ - def __init__(self): - self._valid = False - self._config_file_value = None - self._config_dict_value = None - - @staticmethod - def config_file_key(): - """ - Key in the config file for this option - """ - raise NotImplementedError - - @classmethod - def prompt_title(cls): - """ - Title to print for prompt - """ - return cls.config_file_key() - - @classmethod - def prompt_message(cls): - """ - Message to print for prompt - """ - return None - - @classmethod - def visibility(cls): - return Visibility.DEFAULT - - def optional(self): - """ - If True, then this option can be set to None - """ - return False - - def suggestions(self): - """ - Return a list of Suggestions - """ - return [] - - def default_value(self, suggestions=None): - """ - Utility for retrieving the default value from the suggestion list - """ - if suggestions is None: - suggestions = self.suggestions() - for s in suggestions: - if s.default: - return s.value - return None - - @staticmethod - def is_path(): - """ - If True, tab autocompletion will be turned on during prompt - """ - return False - - @staticmethod - def has_test_value(): - """ - If true, use test_value during testing - """ - return False - - @staticmethod - def test_value(): - """ - Returns a special value to be used during testing - Ignores the current configuration - """ - raise NotImplementedError - - def valid(self): - """ - Returns True if this option has been set with a valid value - """ - return self._valid - - def has_value(self): - """ - Returns False if value is either None or '' - """ - return self.valid() and bool(self._config_file_value) - - @classmethod - def validate(cls, value): - """ - Returns a fixed-up valid version of value - Raises BadValue if invalid - """ - return value - - def set(self, value): - """ - Set the value - Raises BadValue - """ - value = self.validate(value) - self._config_file_value = value - self._set_config_dict_value(value) - self._valid = True - - def _set_config_dict_value(self, value): - """ - Set _config_dict_value according to a validated value - You may want to override this to store more detailed information - """ - self._config_dict_value = value - - def config_dict_value(self): - return self._config_dict_value - - def apply(self): - """ - Apply this configuration - (may involve altering the PATH) - """ - pass - -class FrameworkOption(Option): - """ - Base class for DL framework backends - """ - def optional(self): - return True - diff --git a/digits/config/current_config.py b/digits/config/current_config.py deleted file mode 100644 index e03c22200..000000000 --- a/digits/config/current_config.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -from .caffe_option import CaffeOption -from .extension_list import DataExtensionListOption, ViewExtensionListOption -from .gpu_list import GpuListOption -from .jobs_dir import JobsDirOption -from .log_file import LogFileOption -from .log_level import LogLevelOption -from .torch_option import TorchOption -from .server_name import ServerNameOption -from .secret_key import SecretKeyOption - -option_list = None - -def reset(): - """ - Reset option_list to a list of unset Options - """ - global option_list - - option_list = [ - JobsDirOption(), - GpuListOption(), - LogFileOption(), - LogLevelOption(), - ServerNameOption(), - SecretKeyOption(), - CaffeOption(), - TorchOption(), - DataExtensionListOption(), - ViewExtensionListOption(), - ] - -reset() - -def config_value(key): - """ - Return the current configuration value for the given option - - Arguments: - key -- the key of the configuration option - """ - for option in option_list: - if key == option.config_file_key(): - if not option.valid(): - raise RuntimeError('No valid value set for "%s"' % key) - return option.config_dict_value() - raise RuntimeError('No option found for "%s"' % key) - diff --git a/digits/config/edit.py b/digits/config/edit.py deleted file mode 100755 index d4a92e20b..000000000 --- a/digits/config/edit.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/env python2 -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import argparse -import os - -from . import config_file -from . import config_option -from . import current_config -from . import prompt - -def print_config(verbose=False): - """ - Prints out a matrix of config option values for each level - """ - min_visibility = config_option.Visibility.DEFAULT - if verbose: - min_visibility = config_option.Visibility.HIDDEN - - levels = [('INSTANCE', config_file.InstanceConfigFile())] - # filter out the files which don't exist - levels = [l for l in levels if l[1].can_read()] - - if len(levels) == 0: - # nothing to display - return None - - # create a row for each option - row_headers = [] - row_data = [] - for option in [o for o in current_config.option_list - if o.visibility() >= min_visibility]: - row_headers.append(option.prompt_title()) - row = [] - for title, config in levels: - value = config.get(option.config_file_key()) - row.append(prompt.value_to_str(value)) - row_data.append(row) - - prompt.print_section_header('Current Config') - - # calculate the width of each column for pretty printing - row_header_width = max([len(h) for h in row_headers]) - - row_data_widths = [] - for i, level in enumerate(levels): - title, config = level - w = len(title) - for row in row_data: - if len(row[i]) > w: - w = len(row[i]) - row_data_widths.append(w) - - # build the format string for printing - row_format = '%%%ds' % row_header_width - for width in row_data_widths: - row_format += ' | %%-%ds' % width - - # print header row - print row_format % (('',) + tuple([level[0] for level in levels])) - - # print option rows - for i, row in enumerate(row_data): - print row_format % ((row_headers[i],) + tuple(row)) - print - - -def edit_config_file(verbose=False): - """ - Prompt the user for which file to edit, - then allow them to set options in that file - """ - suggestions = [] - instanceConfig = config_file.InstanceConfigFile() - if instanceConfig.can_write(): - suggestions.append(prompt.Suggestion( - instanceConfig.filename(), 'I', - desc = 'Instance', default=True)) - - def filenameValidator(filename): - """ - Returns True if this is a valid file to edit - """ - if os.path.isfile(filename): - if not os.access(filename, os.W_OK): - raise config_option.BadValue('You do not have write permission') - else: - return filename - - if os.path.isdir(filename): - raise config_option.BadValue('This is a directory') - dirname = os.path.dirname(os.path.realpath(filename)) - if not os.path.isdir(dirname): - raise config_option.BadValue('Path not found: %s' % dirname) - elif not os.access(dirname, os.W_OK): - raise config_option.BadValue('You do not have write permission') - return filename - - filename = prompt.get_input( - message = 'Which file do you want to edit?', - suggestions = suggestions, - validator = filenameValidator, - is_path = True, - ) - - print 'Editing file at %s ...' % os.path.realpath(filename) - print - - is_standard_location = False - - if filename == instanceConfig.filename(): - is_standard_location = True - instanceConfig = None - - configFile = config_file.ConfigFile(filename) - - min_visibility = config_option.Visibility.DEFAULT - if verbose: - min_visibility = config_option.Visibility.HIDDEN - - # Loop through the visible options - for option in [o for o in current_config.option_list - if o.visibility() >= min_visibility]: - previous_value = configFile.get(option.config_file_key()) - suggestions = [prompt.Suggestion(None, 'U', - desc='unset', default=(previous_value is None))] - if previous_value is not None: - suggestions.append(prompt.Suggestion(previous_value, '', - desc = 'Previous', default = True)) - if instanceConfig is not None: - instance_value = instanceConfig.get(option.config_file_key()) - if instance_value is not None: - suggestions.append(prompt.Suggestion(instance_value, 'I', - desc = 'Instance', default = is_standard_location)) - suggestions += option.suggestions() - if option.optional(): - suggestions.append(prompt.Suggestion('', 'N', - desc = 'none', default = True)) - - prompt.print_section_header(option.prompt_title()) - value = prompt.get_input( - message = option.prompt_message(), - validator = option.validate, - suggestions = suggestions, - is_path = option.is_path(), - ) - print - configFile.set(option.config_file_key(), value) - - configFile.save() - print 'New config saved at %s' % configFile.filename() - print - print configFile - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Config - DIGITS') - parser.add_argument('-v', '--verbose', - action="store_true", - help='view more options') - - args = vars(parser.parse_args()) - - print_config(args['verbose']) - edit_config_file(args['verbose']) - diff --git a/digits/config/extension_list.py b/digits/config/extension_list.py index 3486280ec..67fd8bbf1 100644 --- a/digits/config/extension_list.py +++ b/digits/config/extension_list.py @@ -1,122 +1,15 @@ # Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import -from digits import extensions -from . import config_option -from . import prompt - - -class ExtensionListOption(config_option.Option): - """ - Common interface for extension list options - """ - @classmethod - def prompt_message(cls): - s = 'Available extensions:\n' - for extension in cls.get_extensions(True): - s += '\tID=\'%s\' Title=\'%s\'\n' % (extension.get_id(), extension.get_title()) - s += '\nInput the IDs of the extensions you would like to use, separated by commas.' - return s - - def optional(self): - return True - - def suggestions(self): - if len(self.get_extensions(False)) > 0: - return [ - prompt.Suggestion( - ','.join([ext.get_id() for ext in self.get_extensions(False)]), - 'D', desc='default', default=True), - prompt.Suggestion( - ','.join([ext.get_id() for ext in self.get_extensions(True)]), - 'A', desc='all', default=False), - ] - else: - return [] - - @classmethod - def visibility(cls): - if len(cls.get_extensions(True)) == 0: - # Nothing to see here - return config_option.Visibility.NEVER - else: - return config_option.Visibility.DEFAULT - - @classmethod - def validate(cls, value): - if value == '': - return value - - choices = [] - extensions = cls.get_extensions(True) - - if not len(extensions): - return '' - if len(extensions) and not value.strip(): - raise config_option.BadValue('Empty list') - for word in value.split(','): - if not word: - continue - if not cls.get_extension(word): - raise config_option.BadValue('There is no extension with ID=`%s`' % word) - if word in choices: - raise config_option.BadValue('You cannot select an extension twice') - choices.append(word) - - if len(choices) > 0: - return ','.join(choices) - else: - raise config_option.BadValue('Empty list') +import os - def _set_config_dict_value(self, value): - """ - Set _config_dict_value according to a validated value - """ - extensions = [] - for word in value.split(','): - extension = self.get_extension(word) - if extension is not None: - extensions.append(extension) - self._config_dict_value = extensions - - -class DataExtensionListOption(ExtensionListOption): - """ - Extension list sub-class for data extensions - """ - @staticmethod - def config_file_key(): - return 'data_extension_list' - - @classmethod - def prompt_title(cls): - return 'Data extensions' - - @classmethod - def get_extension(cls, extension_id): - return extensions.data.get_extension(extension_id) - - @classmethod - def get_extensions(cls, show_all): - return extensions.data.get_extensions(show_all=show_all) - - -class ViewExtensionListOption(ExtensionListOption): - """ - Extension list sub-class for data extensions - """ - @staticmethod - def config_file_key(): - return 'view_extension_list' - - @classmethod - def prompt_title(cls): - return 'View extensions' +from . import option_list +from digits import extensions - @classmethod - def get_extension(cls, extension_id): - return extensions.view.get_extension(extension_id) - @classmethod - def get_extensions(cls, show_all): - return extensions.view.get_extensions(show_all=show_all) +if 'DIGITS_ALL_EXTENSIONS' in os.environ: + option_list['data_extension_list'] = extensions.data.get_extensions(show_all=True) + option_list['view_extension_list'] = extensions.view.get_extensions(show_all=True) +else: + option_list['data_extension_list'] = extensions.data.get_extensions(show_all=False) + option_list['view_extension_list'] = extensions.view.get_extensions(show_all=False) diff --git a/digits/config/gpu_list.py b/digits/config/gpu_list.py index 13b065090..61c1e8131 100644 --- a/digits/config/gpu_list.py +++ b/digits/config/gpu_list.py @@ -1,91 +1,9 @@ # Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import -import math - -from . import config_option -from . import prompt +from . import option_list import digits.device_query -class GpuListOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'gpu_list' - - @classmethod - def prompt_title(cls): - return 'GPUs' - - @classmethod - def prompt_message(cls): - s = 'Attached devices:\n' - for device_id, gpu in enumerate(digits.device_query.get_devices()): - s += 'Device #%s:\n' % device_id - s += '\t%-20s %s\n' % ('Name', gpu.name) - s += '\t%-20s %s.%s\n' % ('Compute capability', gpu.major, gpu.minor) - s += '\t%-20s %s\n' % ('Memory', cls.convert_size(gpu.totalGlobalMem)) - s += '\t%-20s %s\n' % ('Multiprocessors', gpu.multiProcessorCount) - s += '\n' - return s + '\nInput the IDs of the devices you would like to use, separated by commas, in order of preference.' - - def optional(self): - return True - - def suggestions(self): - if len(digits.device_query.get_devices()) > 0: - return [prompt.Suggestion( - ','.join([str(x) for x in xrange(len(digits.device_query.get_devices()))]), - 'D', desc='default', default=True)] - else: - return [] - - @classmethod - def visibility(cls): - if len(digits.device_query.get_devices()) == 0: - # Nothing to see here - return config_option.Visibility.NEVER - else: - return config_option.Visibility.DEFAULT - - @classmethod - def validate(cls, value): - if value == '': - return value - - choices = [] - gpus = digits.device_query.get_devices() - - if not gpus: - return '' - if len(gpus) and not value.strip(): - raise config_option.BadValue('Empty list') - for word in value.split(','): - if not word: - continue - try: - num = int(word) - except ValueError as e: - raise config_option.BadValue(e.message) - - if not 0 <= num < len(gpus): - raise config_option.BadValue('There is no GPU #%d' % num) - if num in choices: - raise config_option.BadValue('You cannot select a GPU twice') - choices.append(num) - - if len(choices) > 0: - return ','.join(str(n) for n in choices) - else: - raise config_option.BadValue('Empty list') - @classmethod - def convert_size(cls, size): - size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") - i = int(math.floor(math.log(size,1024))) - p = math.pow(1024,i) - s = round(size/p,2) - if (s > 0): - return '%s %s' % (s,size_name[i]) - else: - return '0B' +option_list['gpu_list'] = ','.join([str(x) for x in xrange(len(digits.device_query.get_devices()))]) diff --git a/digits/config/jobs_dir.py b/digits/config/jobs_dir.py index 5c43a6f7b..732b0b1d9 100644 --- a/digits/config/jobs_dir.py +++ b/digits/config/jobs_dir.py @@ -4,57 +4,32 @@ import os import tempfile -from . import config_option -from . import prompt +from . import option_list import digits -class JobsDirOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'jobs_dir' - - @classmethod - def prompt_title(cls): - return 'Jobs Directory' - - @classmethod - def prompt_message(cls): - return 'Where would you like to store job data?' - - def suggestions(self): - d = os.path.join( - os.path.dirname(digits.__file__), - 'jobs') - return [prompt.Suggestion(d, 'D', desc='default', default=True)] - - @staticmethod - def is_path(): - return True - - @staticmethod - def has_test_value(): - return True - - @staticmethod - def test_value(): - return tempfile.mkdtemp() - - @classmethod - def validate(cls, value): - value = os.path.abspath(value) - if os.path.exists(value): - if not os.path.isdir(value): - raise config_option.BadValue('Is not a directory') - if not os.access(value, os.W_OK): - raise config_option.BadValue('You do not have write permission') - return value - if not os.path.exists(os.path.dirname(value)): - raise config_option.BadValue('Parent directory does not exist') - if not os.access(os.path.dirname(value), os.W_OK): - raise config_option.BadValue('You do not have write permission') - return value - - def apply(self): - if not os.path.exists(self._config_file_value): - # make the directory - os.mkdir(self._config_file_value) + +if 'DIGITS_MODE_TEST' in os.environ: + value = tempfile.mkdtemp() +elif 'DIGITS_JOBS_DIR' in os.environ: + value = os.environ['DIGITS_JOBS_DIR'] +else: + value = os.path.join(os.path.dirname(digits.__file__), 'jobs') + + +try: + value = os.path.abspath(value) + if os.path.exists(value): + if not os.path.isdir(value): + raise IOError('No such directory: "%s"' % value) + if not os.access(value, os.W_OK): + raise IOError('Permission denied: "%s"' % value) + if not os.path.exists(value): + os.makedirs(value) +except: + print '"%s" is not a valid value for jobs_dir.' % value + print 'Set the envvar DIGITS_JOBS_DIR to fix your configuration.' + raise + + +option_list['jobs_dir'] = value + diff --git a/digits/config/load.py b/digits/config/load.py deleted file mode 100644 index 537c3f962..000000000 --- a/digits/config/load.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import os - -from . import config_file -from . import config_option -from . import current_config -from . import prompt - -def load_option(option, mode, newConfig, - instanceConfig=None): - """ - Called from load_config() [below] - - Arguments: - option -- an Option instance - mode -- see docstring for load_config() - newConfig -- an instance of ConfigFile - instanceConfig -- the current InstanceConfigFile - """ - if 'DIGITS_MODE_TEST' in os.environ and option.has_test_value(): - option.set(option.test_value()) - return - - suggestions = [] - instance_value = instanceConfig.get(option.config_file_key()) - if instance_value is not None: - suggestions.append(prompt.Suggestion(instance_value, '', - desc = 'Previous', default = True)) - suggestions += option.suggestions() - if option.optional(): - suggestions.append(prompt.Suggestion('', 'N', - desc = 'none', default = True)) - - # Try to use the default values for options less than - # or equal to (LTE) this value - try_defaults_lte = config_option.Visibility.DEFAULT - - if mode == 'verbose': - try_defaults_lte = config_option.Visibility.NEVER - elif mode == 'normal': - try_defaults_lte = config_option.Visibility.HIDDEN - elif mode == 'quiet': - pass - elif mode == 'force': - pass - else: - raise config_option.BadValue('Unknown mode "%s"' % mode) - - valid = False - if option.visibility() <= try_defaults_lte: - # check for a valid default value - for s in [s for s in suggestions if s.default]: - try: - option.set(s.value) - valid = True - break - except config_option.BadValue as e: - print 'Default value for %s "%s" invalid:' % (option.config_file_key(), s.value) - print '\t%s' % e - if not valid: - if mode == 'force': - raise RuntimeError('No valid default value found for configuration option "%s"' % option.config_file_key()) - else: - # prompt user for value - prompt.print_section_header(option.prompt_title()) - value = prompt.get_input( - message = option.prompt_message(), - validator = option.validate, - suggestions = suggestions, - is_path = option.is_path(), - ) - print - option.set(value) - newConfig.set(option.config_file_key(), option._config_file_value) - -def load_config(mode='force'): - """ - Load the current config - By default, the user is prompted for values which have not been set already - - Keyword arguments: - mode -- 3 options: - verbose -- prompt for all options - (`python -m digits.config.edit --verbose`) - normal -- accept defaults for hidden options, otherwise prompt - (`digits-devserver --config`, `python -m digits.config.edit`) - quiet -- prompt only for options without valid defaults - (`digits-devserver`) - force -- throw errors for invalid options - (`digits-server`, `digits-test`) - """ - current_config.reset() - - instanceConfig = config_file.InstanceConfigFile() - newConfig = config_file.InstanceConfigFile() - - non_framework_options = [o for o in current_config.option_list - if not isinstance(o, config_option.FrameworkOption)] - framework_options = [o for o in current_config.option_list - if isinstance(o, config_option.FrameworkOption)] - - # Load non-framework config options - for option in non_framework_options: - load_option(option, mode, newConfig, instanceConfig) - - has_one_framework = False - verbose_for_frameworks = False - while not has_one_framework: - # Load framework config options - if verbose_for_frameworks and mode == 'quiet': - framework_mode = 'verbose' - else: - framework_mode = mode - for option in framework_options: - load_option(option, framework_mode, newConfig, instanceConfig) - if option.has_value(): - has_one_framework = True - - if not has_one_framework: - errstr = 'DIGITS requires at least one DL backend to run.' - if mode == 'force': - raise RuntimeError(errstr) - else: - print errstr - # try again prompting all - verbose_for_frameworks = True - - for option in current_config.option_list: - option.apply() - - if newConfig.dirty() and newConfig.can_write(): - newConfig.save() - print 'Saved config to %s' % newConfig.filename() - diff --git a/digits/config/log_file.py b/digits/config/log_file.py index 12552a0ee..28929c896 100644 --- a/digits/config/log_file.py +++ b/digits/config/log_file.py @@ -1,87 +1,73 @@ # Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import +import logging import os +import sys -from . import config_option -from . import prompt +from . import option_list import digits -class LogFileOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'log_file' - @classmethod - def prompt_title(cls): - return 'Log File' - - @classmethod - def prompt_message(cls): - return 'Where do you want the log files to be stored?' - - def optional(self): - # if not set, no log will be saved - return True - - def suggestions(self): - suggested_dir = os.path.dirname(digits.__file__) - - if os.access(suggested_dir, os.W_OK): - return [prompt.Suggestion( - os.path.join(suggested_dir, 'digits.log'), 'D', - desc='default', default=True) - ] +def load_logfile_filename(): + """ + Return the configured log file or None + Throws an exception only if a manually specified log file is invalid + """ + throw_error = False + if 'DIGITS_MODE_TEST' in os.environ: + filename = None + elif 'DIGITS_LOGFILE_FILENAME' in os.environ: + filename = os.environ['DIGITS_LOGFILE_FILENAME'] + throw_error = True + else: + filename = os.path.join(os.path.dirname(digits.__file__), 'digits.log') + + + if filename is not None: + try: + filename = os.path.abspath(filename) + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + os.makedirs(os.path.dirname(filename)) + with open(filename, 'w'): + pass + except: + if throw_error: + print '"%s" is not a valid value for logfile_filename.' % filename + print 'Set the envvar DIGITS_LOGFILE_FILENAME to fix your configuration.' + raise + else: + filename = None + return filename + + +def load_logfile_level(): + """ + Return the configured logging level, or throw an exception + """ + if 'DIGITS_MODE_TEST' in os.environ: + return logging.DEBUG + elif 'DIGITS_LOGFILE_LEVEL' in os.environ: + level = os.environ['DIGITS_LOGFILE_LEVEL'].strip().lower() + if level == 'debug': + return logging.DEBUG + elif level == 'info': + return logging.INFO + elif level == 'warning': + return logging.WARNING + elif level == 'error': + return logging.ERROR + elif level == 'critical': + return logging.CRITICAL else: - return [] - - @staticmethod - def is_path(): - return True - - @staticmethod - def has_test_value(): - return True - - @staticmethod - def test_value(): - return None - - @classmethod - def validate(cls, value): - if not value: - return value - value = os.path.abspath(value) - dirname = os.path.dirname(value) - - if os.path.isfile(value): - if not os.access(value, os.W_OK): - raise config_option.BadValue('You do not have write permissions') - if not os.access(dirname, os.W_OK): - raise config_option.BadValue('You do not have write permissions for "%s"' % dirname) - return value - elif os.path.isdir(value): - raise config_option.BadValue('"%s" is a directory' % value) - else: - if os.path.isdir(dirname): - if not os.access(dirname, os.W_OK): - raise config_option.BadValue('You do not have write permissions for "%s"' % dirname) - # filename is in a valid directory - return value - previous_dir = os.path.dirname(dirname) - if not os.path.isdir(previous_dir): - raise config_option.BadValue('"%s" not found' % value) - if not os.access(previous_dir, os.W_OK): - raise config_option.BadValue('You do not have write permissions for "%s"' % previous_dir) - # the preceding directory can be created later (in apply()) - return value - - def apply(self): - if not self._config_file_value: - return + raise ValueError('Invalid value "%s" for logfile_level. Set DIGITS_LOGFILE_LEVEL to fix your configuration.' % level) + else: + return logging.INFO - dirname = os.path.dirname(self._config_file_value) - if not os.path.exists(dirname): - os.mkdir(dirname) +option_list['log_file'] = { + 'filename': load_logfile_filename(), + 'level': load_logfile_level(), +} diff --git a/digits/config/log_level.py b/digits/config/log_level.py deleted file mode 100644 index 544ffdbc9..000000000 --- a/digits/config/log_level.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -from . import config_option -from . import prompt - -class LogLevelOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'log_level' - - @classmethod - def prompt_title(cls): - return 'Log Level' - - @classmethod - def prompt_message(cls): - return 'What is the minimum log level that you want to save to your logfile? [error/warning/info/debug]' - - @classmethod - def visibility(cls): - return config_option.Visibility.HIDDEN - - def suggestions(self): - return [ - prompt.Suggestion('debug', 'D'), - prompt.Suggestion('info', 'I', default=True), - prompt.Suggestion('warning', 'W'), - prompt.Suggestion('error', 'E'), - ] - - @classmethod - def validate(cls, value): - value = value.strip().lower() - if value not in ['error', 'warning', 'info', 'debug']: - raise config_option.BadValue - return value - diff --git a/digits/config/prompt.py b/digits/config/prompt.py deleted file mode 100644 index 956b3f01f..000000000 --- a/digits/config/prompt.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -""" -Classes and functions relating to prompting a user for configuration options -""" -from __future__ import absolute_import - -import os.path -import readline -import sys - -from . import config_option - -def print_section_header(title): - """ - Utility for printing a section header - """ - print '{s:{c}^{n}}'.format( - s = ' %s ' % title, - # Extend to 80 characters - n = 80, c = '=') - -def value_to_str(value): - if value is None: - return '' - elif type(value) is not str: - return str(value) - elif not value.strip(): - return '' - else: - return value - -class Suggestion(object): - """ - A simple class for Option suggested values (used in get_input()) - """ - def __init__(self, value, char, - desc = None, - default = False, - ): - """ - Arguments: - value -- the suggested value - char -- a 1 character token representing this suggestion - - Keyword arguments: - desc -- a short description of the source of this suggestion - default -- if True, this is the suggestion that will be accepted by default - """ - self.value = value - if not isinstance(char, str): - raise ValueError('char must be a string') - if not (char == '' or len(char) == 1): - raise ValueError('char must be a single character') - self.char = char - self.desc = desc - self.default = default - - def __str__(self): - s = ' max_width: - max_width = len(s.desc) - if max_width > 0: - print '\tSuggested values:' - format_str = '\t%%-4s %%-%ds %%s' % (max_width+2,) - default_found = False - for s in suggestions: - c = s.char - if s.default and not default_found: - default_found = True - c += '*' - desc = '' - if s.desc is not None: - desc = '[%s]' % s.desc - print format_str % (('(%s)' % c), desc, value_to_str(s.value)) - - if is_path: - # turn on filename autocompletion - delims = readline.get_completer_delims() - readline.set_completer_delims(' \t\n;') - readline.parse_and_bind('TAB: complete') - - user_input = None - value = None - valid = False - while not valid: - try: - # Get user input - user_input = raw_input('>> ').strip() - except (KeyboardInterrupt, EOFError): - print - sys.exit(0) - - if user_input == '': - for s in suggestions: - if s.default: - print 'Using "%s"' % s.value - if s.value is not None and validator is not None: - try: - value = validator(s.value) - valid = True - break - except config_option.BadValue as e: - print 'ERROR:', e - else: - value = s.value - valid = True - break - else: - if len(user_input) == 1: - for s in suggestions: - if s.char.lower() == user_input.lower(): - print 'Using "%s"' % s.value - if s.value is not None and validator is not None: - try: - value = validator(s.value) - valid = True - break - except config_option.BadValue as e: - print 'ERROR:', e - else: - value = s.value - valid = True - break - if not valid and validator is not None: - if is_path: - user_input = os.path.expanduser(user_input) - try: - value = validator(user_input) - valid = True - print 'Using "%s"' % value - except config_option.BadValue as e: - print 'ERROR:', e - - if not valid: - print 'Invalid input' - - if is_path: - # back to normal - readline.set_completer_delims(delims) - readline.parse_and_bind('TAB: ') - - return value diff --git a/digits/config/secret_key.py b/digits/config/secret_key.py deleted file mode 100644 index 90e574969..000000000 --- a/digits/config/secret_key.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import os - -from . import config_option -from . import prompt - -class SecretKeyOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'secret_key' - - @classmethod - def visibility(cls): - return config_option.Visibility.NEVER - - def suggestions(self): - key = os.urandom(12).encode('hex') - return [prompt.Suggestion(key, 'D', desc='default', default=True)] - diff --git a/digits/config/server_name.py b/digits/config/server_name.py index d565d6726..26fa6e3bd 100644 --- a/digits/config/server_name.py +++ b/digits/config/server_name.py @@ -1,29 +1,15 @@ # Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import +import os import platform -from . import config_option -from . import prompt +from . import option_list -class ServerNameOption(config_option.Option): - @staticmethod - def config_file_key(): - return 'server_name' - - @classmethod - def prompt_title(cls): - return 'Server Name' - - @classmethod - def visibility(cls): - return config_option.Visibility.HIDDEN - - def optional(self): - return True - - def suggestions(self): - hostname = platform.node() - return [prompt.Suggestion(hostname, 'H', desc='HOSTNAME')] +if 'DIGITS_SERVER_NAME' in os.environ: + value = os.environ['DIGITS_SERVER_NAME'] +else: + value = platform.node() +option_list['server_name'] = value diff --git a/digits/config/test_config_file.py b/digits/config/test_config_file.py deleted file mode 100644 index 2b56da224..000000000 --- a/digits/config/test_config_file.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import tempfile - -from . import config_file - -class TestConfigFile(): - def test_write_and_read(self): - for args in [ - ('name', 'value'), - ('blank', ''), - ]: - yield self.check_val, args - - def check_val(self, args): - name, value = args - - filename = None - with tempfile.NamedTemporaryFile(suffix='cfg') as tmp: - filename = tmp.name - cf1 = config_file.ConfigFile(filename) - assert not cf1.exists(), 'tempfile already exists' - assert cf1.can_write(), "can't write to tempfile" - - cf1.set(name, value) - cf1.save() - - cf2 = config_file.ConfigFile(filename) - assert cf2.exists(), "tempfile doesn't exist" - assert cf2.can_read(), "can't read from tempfile" - - cf2.load() - assert cf2.get(name) == value, \ - '"%s" is "%s", not "%s"' % (name, cf2.get(name), value) - diff --git a/digits/config/test_prompt.py b/digits/config/test_prompt.py deleted file mode 100644 index 3dd422d4c..000000000 --- a/digits/config/test_prompt.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -from contextlib import contextmanager -import sys - -# Find the best implementation available -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO - -import mock -from nose.tools import raises - -from . import prompt - -class TestValueToStr(): - def test_none(self): - # pass none to value_to_str - assert prompt.value_to_str(None) == '', 'passing None should return an empty string' - - def test_nonstring(self): - # pass a non-string value to value_to_str - assert prompt.value_to_str(1) == '1', 'passing 1 should return the string "1"' - -class TestSuggestion(): - @raises(ValueError) - def test_new_bad_char_type(self): - # pass a non-string type as char to suggestion - prompt.Suggestion(None, 1) - - @raises(ValueError) - def test_new_bad_multichar(self): - # pass multiple chars where one is expected - prompt.Suggestion(None, 'badvalue') - - def test_str_method(self): - # test __str__ method of Suggestion - suggestion = prompt.Suggestion('alpha', 'a', 'test', True) - strval = str(suggestion) - expect = '' - - assert strval == expect, 'Suggestion is not producing the correct string value %s' % expect - -@contextmanager -def mockInput(fn): - original = __builtins__['raw_input'] - __builtins__['raw_input'] = fn - yield - __builtins__['raw_input'] = original - -class TestGetInput(): - def setUp(self): - self.suggestions = [prompt.Suggestion('alpha', 'a', 'test', False)] - - @raises(SystemExit) - def test_get_input_sys_exit(self): - # bad input from user - def temp(_): - raise KeyboardInterrupt - - with mockInput(temp): - prompt.get_input('Test', lambda _: True, self.suggestions) - - def test_get_input_empty_then_full(self): - # test both major paths of get_input - # Python 2 does not have the 'nonlocal' keyword, so we fudge the closure with an object. - class Temp: - def __init__(self): - self.flag = False - def __call__(self, _): - if not self.flag: - self.flag = True - return '' - else: - return 'a' - - with mockInput(Temp()): - assert prompt.get_input('Test', lambda x: x, self.suggestions) == 'alpha', 'get_input should return "alpha" for input "a"' - - def test_get_input_empty_default(self): - # empty input should choose the default - self.suggestions[0].default = True - - with mockInput(lambda _: ''): - assert prompt.get_input('Test', lambda x: x+'_validated', self.suggestions) == 'alpha_validated', 'get_input should return the default value "alpha"' - - def test_get_input_empty_default_no_validator(self): - # empty input should choose the default and not validate - self.suggestions[0].default = True - - with mockInput(lambda _: ''): - assert prompt.get_input('Test', suggestions=self.suggestions) == 'alpha', 'get_input should return the default value "alpha"' - - @mock.patch('os.path.expanduser') - def test_get_input_path(self, mock_expanduser): - # should correctly validate path - mock_expanduser.side_effect = lambda x: '/path'+x - - with mockInput(lambda _: '/test'): - assert prompt.get_input(validator=lambda x: x, is_path=True) == '/path/test', 'get_input should return the default value "alpha"' - diff --git a/digits/config/torch.py b/digits/config/torch.py new file mode 100644 index 000000000..ff590f6b8 --- /dev/null +++ b/digits/config/torch.py @@ -0,0 +1,56 @@ +# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. +from __future__ import absolute_import + +import os + +from . import option_list + + +def find_executable(path=None): + """ + Finds th on the given path and returns it if found + If path is None, searches through PATH + """ + if path is None: + dirnames = os.environ['PATH'].split(os.pathsep) + suffixes = ['th'] + else: + dirnames = [path] + # fuzzy search + suffixes = ['th', + os.path.join('bin', 'th'), + os.path.join('install', 'bin', 'th')] + + for dirname in dirnames: + dirname = dirname.strip('"') + for suffix in suffixes: + path = os.path.join(dirname, suffix) + if os.path.isfile(path) and os.access(path, os.X_OK): + return path + return None + + +if 'TORCH_ROOT' in os.environ: + executable = find_executable(os.environ['TORCH_ROOT']) + if executable is None: + raise ValueError('Torch executable not found at "%s" (TORCH_ROOT)' + % os.environ['TORCH_ROOT']) +elif 'TORCH_HOME' in os.environ: + executable = find_executable(os.environ['TORCH_HOME']) + if executable is None: + raise ValueError('Torch executable not found at "%s" (TORCH_HOME)' + % os.environ['TORCH_HOME']) +else: + executable = find_executable() + + +if executable is None: + option_list['torch'] = { + 'enabled': False, + } +else: + option_list['torch'] = { + 'enabled': True, + 'executable': executable, + } + diff --git a/digits/config/torch_option.py b/digits/config/torch_option.py deleted file mode 100644 index faae1dd21..000000000 --- a/digits/config/torch_option.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved. -from __future__ import absolute_import - -import os - -from . import config_option -from . import prompt - -class TorchOption(config_option.FrameworkOption): - @staticmethod - def config_file_key(): - return 'torch_root' - - @classmethod - def prompt_title(cls): - return 'Torch' - - @classmethod - def prompt_message(cls): - return 'Where is torch installed?' - - def optional(self): - return True - - def suggestions(self): - suggestions = [] - if 'TORCH_ROOT' in os.environ: - d = os.environ['TORCH_ROOT'] - try: - suggestions.append(prompt.Suggestion( - self.validate(d), 'R', - desc='TORCH_ROOT', default=True)) - except config_option.BadValue as e: - print 'TORCH_ROOT "%s" is invalid:' % d - print '\t%s' % e - if 'TORCH_HOME' in os.environ: - d = os.environ['TORCH_HOME'] - try: - default = True - if len(suggestions) > 0: - default = False - suggestions.append(prompt.Suggestion( - self.validate(d), 'H', - desc='TORCH_HOME', default=default)) - except config_option.BadValue as e: - print 'TORCH_HOME "%s" is invalid:' % d - print '\t%s' % e - suggestions.append(prompt.Suggestion('', 'P', - desc='PATH/TORCHPATH', default=True)) - return suggestions - - @staticmethod - def is_path(): - return True - - @classmethod - def validate(cls, value): - if not value: - return value - - if value == '': - # Find the executable - executable = cls.find_executable('th') - if not executable: - raise config_option.BadValue('torch binary not found in PATH') - #cls.validate_version(executable) - return value - else: - # Find the executable - value = os.path.abspath(value) - if not os.path.isdir(value): - raise config_option.BadValue('"%s" is not a directory' % value) - expected_path = os.path.join(value, 'bin', 'th') - if not os.path.exists(expected_path): - raise config_option.BadValue('torch binary not found at "%s"' % value) - #cls.validate_version(expected_path) - return value - - @staticmethod - def find_executable(program): - """ - Finds an executable by searching through PATH - Returns the path to the executable or None - """ - for path in os.environ['PATH'].split(os.pathsep): - path = path.strip('"') - executable = os.path.join(path, program) - if os.path.isfile(executable) and os.access(executable, os.X_OK): - return executable - return None - - @classmethod - def validate_version(cls, executable): - """ - Utility for checking the caffe version from within validate() - Throws BadValue - - Arguments: - executable -- path to a caffe executable - """ - # Currently DIGITS don't have any restrictions on Torch version, so no need to implement this. - pass - - def apply(self): - pass diff --git a/digits/dataset/images/generic/test_lmdb_creator.py b/digits/dataset/images/generic/test_lmdb_creator.py index 17c712683..43a534209 100755 --- a/digits/dataset/images/generic/test_lmdb_creator.py +++ b/digits/dataset/images/generic/test_lmdb_creator.py @@ -25,10 +25,9 @@ if __name__ == '__main__': dirname = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, os.path.join(dirname,'..','..','..','..')) - from digits.config.load import load_config - load_config() + import digits.config -# Run load_config() first to set the path to Caffe +# Import digits.config first to set the path to Caffe import caffe_pb2 diff --git a/digits/frameworks/__init__.py b/digits/frameworks/__init__.py index c55137e85..f3072d584 100644 --- a/digits/frameworks/__init__.py +++ b/digits/frameworks/__init__.py @@ -11,7 +11,7 @@ # # torch is optional -torch = TorchFramework() if config_value('torch_root') else None +torch = TorchFramework() if config_value('torch')['enabled'] else None # caffe is mandatory caffe = CaffeFramework() diff --git a/digits/frameworks/caffe_framework.py b/digits/frameworks/caffe_framework.py index 7574165b2..eeb6c7c58 100644 --- a/digits/frameworks/caffe_framework.py +++ b/digits/frameworks/caffe_framework.py @@ -34,13 +34,13 @@ class CaffeFramework(Framework): # whether this framework can shuffle data during training CAN_SHUFFLE_DATA = False - if config_value('caffe_root')['flavor'] == 'NVIDIA': - if config_value('caffe_root')['version'] > parse_version('0.14.0-alpha'): + if config_value('caffe')['flavor'] == 'NVIDIA': + if parse_version(config_value('caffe')['version']) > parse_version('0.14.0-alpha'): SUPPORTED_SOLVER_TYPES = ['SGD', 'NESTEROV', 'ADAGRAD', 'RMSPROP', 'ADADELTA', 'ADAM'] else: SUPPORTED_SOLVER_TYPES = ['SGD', 'NESTEROV', 'ADAGRAD'] - elif config_value('caffe_root')['flavor'] == 'BVLC': + elif config_value('caffe')['flavor'] == 'BVLC': SUPPORTED_SOLVER_TYPES = ['SGD', 'NESTEROV', 'ADAGRAD', 'RMSPROP', 'ADADELTA', 'ADAM'] else: @@ -144,10 +144,11 @@ def get_network_visualization(self, desc): @override def can_accumulate_gradients(self): - if config_value('caffe_root')['flavor'] == 'BVLC': + if config_value('caffe')['flavor'] == 'BVLC': return True - elif config_value('caffe_root')['flavor'] == 'NVIDIA': - return config_value('caffe_root')['version'] > parse_version('0.14.0-alpha') + elif config_value('caffe')['flavor'] == 'NVIDIA': + return (parse_version(config_value('caffe')['version']) + > parse_version('0.14.0-alpha')) else: raise ValueError('Unknown flavor. Support NVIDIA and BVLC flavors only.') diff --git a/digits/frameworks/torch_framework.py b/digits/frameworks/torch_framework.py index 59b7af5e3..e385290ce 100644 --- a/digits/frameworks/torch_framework.py +++ b/digits/frameworks/torch_framework.py @@ -125,10 +125,7 @@ def get_network_visualization(self, desc): try: # do this in a try..finally clause to make sure we delete the temp file # build command line - if config_value('torch_root') == '': - torch_bin = 'th' - else: - torch_bin = os.path.join(config_value('torch_root'), 'bin', 'th') + torch_bin = config_value('torch')['executable'] args = [torch_bin, os.path.join(os.path.dirname(os.path.dirname(digits.__file__)),'tools','torch','main.lua'), diff --git a/digits/log.py b/digits/log.py index 1e24ee4b4..26f5d71e4 100644 --- a/digits/log.py +++ b/digits/log.py @@ -70,27 +70,20 @@ def setup_logging(): ### digits.webapp logger - if config_value('log_file'): + logfile_filename = config_value('log_file')['filename'] + logfile_level = config_value('log_file')['level'] + + if logfile_filename is not None: webapp_logger = logging.getLogger('digits.webapp') webapp_logger.setLevel(logging.DEBUG) # Log to file fileHandler = logging.handlers.RotatingFileHandler( - config_value('log_file'), + logfile_filename, maxBytes=(1024*1024*10), # 10 MB backupCount=10, ) fileHandler.setFormatter(formatter) - level = config_value('log_level') - if level == 'debug': - fileHandler.setLevel(logging.DEBUG) - elif level == 'info': - fileHandler.setLevel(logging.INFO) - elif level == 'warning': - fileHandler.setLevel(logging.WARNING) - elif level == 'error': - fileHandler.setLevel(logging.ERROR) - elif level == 'critical': - fileHandler.setLevel(logging.CRITICAL) + fileHandler.setLevel(logfile_level) webapp_logger.addHandler(fileHandler) ### Useful shortcut for the webapp, which may set job_id diff --git a/digits/model/images/classification/test_views.py b/digits/model/images/classification/test_views.py index 4a0b22c57..8fb864b29 100644 --- a/digits/model/images/classification/test_views.py +++ b/digits/model/images/classification/test_views.py @@ -102,7 +102,7 @@ class BaseViewsTest(digits.test_views.BaseViewsTest): @classmethod def setUpClass(cls): super(BaseViewsTest, cls).setUpClass() - if cls.FRAMEWORK=='torch' and not config_value('torch_root'): + if cls.FRAMEWORK=='torch' and not config_value('torch')['enabled']: raise unittest.SkipTest('Torch not found') @classmethod @@ -342,10 +342,10 @@ def test_snapshot_interval_0_5(self): not config_value('gpu_list'), 'no GPUs selected') @unittest.skipIf( - not config_value('caffe_root')['cuda_enabled'], + not config_value('caffe')['cuda_enabled'], 'CUDA disabled') @unittest.skipIf( - config_value('caffe_root')['multi_gpu'], + config_value('caffe')['multi_gpu'], 'multi-GPU enabled') def test_select_gpu(self): for index in config_value('gpu_list').split(','): @@ -359,10 +359,10 @@ def check_select_gpu(self, gpu_index): not config_value('gpu_list'), 'no GPUs selected') @unittest.skipIf( - not config_value('caffe_root')['cuda_enabled'], + not config_value('caffe')['cuda_enabled'], 'CUDA disabled') @unittest.skipIf( - not config_value('caffe_root')['multi_gpu'], + not config_value('caffe')['multi_gpu'], 'multi-GPU disabled') def test_select_gpus(self): # test all possible combinations @@ -797,8 +797,8 @@ def test_inference_while_training(self): # get number of GPUs gpu_count = 1 if (config_value('gpu_list') and - config_value('caffe_root')['cuda_enabled'] and - config_value('caffe_root')['multi_gpu']): + config_value('caffe')['cuda_enabled'] and + config_value('caffe')['multi_gpu']): gpu_count = len(config_value('gpu_list').split(',')) # grab an image for testing diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py index 7fffeffbf..4037defab 100644 --- a/digits/model/images/classification/views.py +++ b/digits/model/images/classification/views.py @@ -81,7 +81,7 @@ def new(): previous_network_snapshots = prev_network_snapshots, previous_networks_fullinfo = get_previous_networks_fulldetails(), pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(), - multi_gpu = config_value('caffe_root')['multi_gpu'], + multi_gpu = config_value('caffe')['multi_gpu'], ) @blueprint.route('.json', methods=['POST']) @@ -115,7 +115,7 @@ def create(): previous_network_snapshots = prev_network_snapshots, previous_networks_fullinfo = get_previous_networks_fulldetails(), pretrained_networks_fullinfo = get_pretrained_networks_fulldetails(), - multi_gpu = config_value('caffe_root')['multi_gpu'], + multi_gpu = config_value('caffe')['multi_gpu'], ), 400 datasetJob = scheduler.get_job(form.dataset.data) @@ -240,7 +240,7 @@ def create(): raise werkzeug.exceptions.BadRequest( 'Invalid learning rate policy') - if config_value('caffe_root')['multi_gpu']: + if config_value('caffe')['multi_gpu']: if form.select_gpus.data: selected_gpus = [str(gpu) for gpu in form.select_gpus.data] gpu_count = None diff --git a/digits/model/images/generic/test_views.py b/digits/model/images/generic/test_views.py index 1b0747fdc..09f265805 100644 --- a/digits/model/images/generic/test_views.py +++ b/digits/model/images/generic/test_views.py @@ -103,7 +103,7 @@ class BaseViewsTest(digits.test_views.BaseViewsTest): @classmethod def setUpClass(cls, **kwargs): super(BaseViewsTest, cls).setUpClass(**kwargs) - if cls.FRAMEWORK == 'torch' and not config_value('torch_root'): + if cls.FRAMEWORK == 'torch' and not config_value('torch')['enabled']: raise unittest.SkipTest('Torch not found') @classmethod @@ -322,10 +322,10 @@ def test_snapshot_interval_0_5(self): not config_value('gpu_list'), 'no GPUs selected') @unittest.skipIf( - not config_value('caffe_root')['cuda_enabled'], + not config_value('caffe')['cuda_enabled'], 'CUDA disabled') @unittest.skipIf( - config_value('caffe_root')['multi_gpu'], + config_value('caffe')['multi_gpu'], 'multi-GPU enabled') def test_select_gpu(self): for index in config_value('gpu_list').split(','): @@ -339,10 +339,10 @@ def check_select_gpu(self, gpu_index): not config_value('gpu_list'), 'no GPUs selected') @unittest.skipIf( - not config_value('caffe_root')['cuda_enabled'], + not config_value('caffe')['cuda_enabled'], 'CUDA disabled') @unittest.skipIf( - not config_value('caffe_root')['multi_gpu'], + not config_value('caffe')['multi_gpu'], 'multi-GPU disabled') def test_select_gpus(self): # test all possible combinations diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py index e74569cf9..b35077c56 100644 --- a/digits/model/images/generic/views.py +++ b/digits/model/images/generic/views.py @@ -50,7 +50,7 @@ def new(extension_id=None): previous_network_snapshots=prev_network_snapshots, previous_networks_fullinfo=get_previous_networks_fulldetails(), pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(), - multi_gpu=config_value('caffe_root')['multi_gpu'], + multi_gpu=config_value('caffe')['multi_gpu'], ) @@ -89,7 +89,7 @@ def create(extension_id=None): previous_network_snapshots=prev_network_snapshots, previous_networks_fullinfo=get_previous_networks_fulldetails(), pretrained_networks_fullinfo=get_pretrained_networks_fulldetails(), - multi_gpu=config_value('caffe_root')['multi_gpu'], + multi_gpu=config_value('caffe')['multi_gpu'], ), 400 datasetJob = scheduler.get_job(form.dataset.data) @@ -203,7 +203,7 @@ def create(extension_id=None): raise werkzeug.exceptions.BadRequest( 'Invalid learning rate policy') - if config_value('caffe_root')['multi_gpu']: + if config_value('caffe')['multi_gpu']: if form.select_gpu_count.data: gpu_count = form.select_gpu_count.data selected_gpus = None diff --git a/digits/model/tasks/caffe_train.py b/digits/model/tasks/caffe_train.py index dd6c2a8ae..b8387ebdd 100644 --- a/digits/model/tasks/caffe_train.py +++ b/digits/model/tasks/caffe_train.py @@ -135,8 +135,8 @@ def __init__(self, **kwargs): self.log_file = self.CAFFE_LOG self.digits_version = digits.__version__ - self.caffe_version = config_value('caffe_root')['ver_str'] - self.caffe_flavor = config_value('caffe_root')['flavor'] + self.caffe_version = config_value('caffe')['version'] + self.caffe_flavor = config_value('caffe')['flavor'] def __getstate__(self): state = super(CaffeTrainTask, self).__getstate__() @@ -502,7 +502,7 @@ def save_files_classification(self): solver.net = self.train_val_file # Set CPU/GPU mode - if config_value('caffe_root')['cuda_enabled'] and \ + if config_value('caffe')['cuda_enabled'] and \ bool(config_value('gpu_list')): solver.solver_mode = caffe_pb2.SolverParameter.GPU else: @@ -727,7 +727,7 @@ def save_files_generic(self): solver.net = self.train_val_file # Set CPU/GPU mode - if config_value('caffe_root')['cuda_enabled'] and \ + if config_value('caffe')['cuda_enabled'] and \ bool(config_value('gpu_list')): solver.solver_mode = caffe_pb2.SolverParameter.GPU else: @@ -903,7 +903,7 @@ def task_arguments(self, resources, env): # Not in Windows, or in Windows but no Python Layer # This is the normal path - args = [config_value('caffe_root')['executable'], + args = [config_value('caffe')['executable'], 'train', '--solver=%s' % self.path(self.solver_file), ] @@ -915,13 +915,14 @@ def task_arguments(self, resources, env): if len(identifiers) == 1: args.append('--gpu=%s' % identifiers[0]) elif len(identifiers) > 1: - if config_value('caffe_root')['flavor'] == 'NVIDIA': - if config_value('caffe_root')['version'] < utils.parse_version('0.14.0-alpha'): + if config_value('caffe')['flavor'] == 'NVIDIA': + if (utils.parse_version(config_value('caffe')['version']) + < utils.parse_version('0.14.0-alpha')): # Prior to version 0.14, NVcaffe used the --gpus switch args.append('--gpus=%s' % ','.join(identifiers)) else: args.append('--gpu=%s' % ','.join(identifiers)) - elif config_value('caffe_root')['flavor'] == 'BVLC': + elif config_value('caffe')['flavor'] == 'BVLC': args.append('--gpu=%s' % ','.join(identifiers)) else: raise ValueError('Unknown flavor. Support NVIDIA and BVLC flavors only.') diff --git a/digits/model/tasks/torch_train.py b/digits/model/tasks/torch_train.py index b6c86353f..acecf353f 100644 --- a/digits/model/tasks/torch_train.py +++ b/digits/model/tasks/torch_train.py @@ -136,15 +136,10 @@ def create_mean_file(self): @override def task_arguments(self, resources, env): - if config_value('torch_root') == '': - torch_bin = 'th' - else: - torch_bin = os.path.join(config_value('torch_root'), 'bin', 'th') - dataset_backend = self.dataset.get_backend() assert dataset_backend=='lmdb' or dataset_backend=='hdf5' - args = [torch_bin, + args = [config_value('torch')['executable'], os.path.join(os.path.dirname(os.path.dirname(digits.__file__)),'tools','torch','wrapper.lua'), 'main.lua', '--network=%s' % self.model_file.split(".")[0], @@ -527,14 +522,9 @@ def infer_one_image(self, image, snapshot_epoch=None, layers=None, gpu=None): self.logger.error(error_message) raise digits.inference.errors.InferenceError(error_message) - if config_value('torch_root') == '': - torch_bin = 'th' - else: - torch_bin = os.path.join(config_value('torch_root'), 'bin', 'th') - file_to_load = self.get_snapshot(snapshot_epoch) - args = [torch_bin, + args = [config_value('torch')['executable'], os.path.join(os.path.dirname(os.path.dirname(digits.__file__)),'tools','torch','wrapper.lua'), 'test.lua', '--image=%s' % temp_image_path, @@ -830,14 +820,9 @@ def infer_many_images(self, images, snapshot_epoch=None, gpu=None): os.close(temp_image_handle) os.close(temp_imglist_handle) - if config_value('torch_root') == '': - torch_bin = 'th' - else: - torch_bin = os.path.join(config_value('torch_root'), 'bin', 'th') - file_to_load = self.get_snapshot(snapshot_epoch) - args = [torch_bin, + args = [config_value('torch')['executable'], os.path.join(os.path.dirname(os.path.dirname(digits.__file__)),'tools','torch','wrapper.lua'), 'test.lua', '--testMany=yes', diff --git a/digits/webapp.py b/digits/webapp.py index 101b238af..6f80c989a 100644 --- a/digits/webapp.py +++ b/digits/webapp.py @@ -1,7 +1,7 @@ # Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved. from __future__ import absolute_import -import os.path +import os import flask from flask.ext.socketio import SocketIO @@ -19,7 +19,7 @@ # Disable CSRF checking in WTForms app.config['WTF_CSRF_ENABLED'] = False # This is still necessary for SocketIO -app.config['SECRET_KEY'] = config_value('secret_key') +app.config['SECRET_KEY'] = os.urandom(12).encode('hex') app.url_map.redirect_defaults = False socketio = SocketIO(app, async_mode='gevent') scheduler = digits.scheduler.Scheduler(config_value('gpu_list'), True) @@ -28,8 +28,8 @@ app.jinja_env.globals['server_name'] = config_value('server_name') app.jinja_env.globals['server_version'] = digits.__version__ -app.jinja_env.globals['caffe_version'] = config_value('caffe_root')['ver_str'] -app.jinja_env.globals['caffe_flavor'] = config_value('caffe_root')['flavor'] +app.jinja_env.globals['caffe_version'] = config_value('caffe')['version'] +app.jinja_env.globals['caffe_flavor'] = config_value('caffe')['flavor'] app.jinja_env.globals['dir_hash'] = fs.dir_hash( os.path.join(os.path.dirname(digits.__file__), 'static')) app.jinja_env.filters['print_time'] = utils.time_filters.print_time diff --git a/docs/BuildCaffe.md b/docs/BuildCaffe.md index 66f358afd..685d8a128 100644 --- a/docs/BuildCaffe.md +++ b/docs/BuildCaffe.md @@ -21,29 +21,29 @@ sudo apt-get install --no-install-recommends build-essential cmake git gfortran ```sh # example location - can be customized -export CAFFE_HOME=~/caffe -git clone https://github.com/NVIDIA/caffe.git $CAFFE_HOME +export CAFFE_ROOT=~/caffe +git clone https://github.com/NVIDIA/caffe.git $CAFFE_ROOT ``` -Setting the `CAFFE_HOME` environment variable will help DIGITS automatically detect your Caffe installation, but this is optional. +Setting the `CAFFE_ROOT` environment variable will help DIGITS automatically detect your Caffe installation, but this is optional. ## Python packages Several PyPI packages need to be installed: ```sh -sudo pip install -r $CAFFE_HOME/python/requirements.txt +sudo pip install -r $CAFFE_ROOT/python/requirements.txt ``` If you hit some errors about missing imports, then use this command to install the packages in order ([see discussion here](https://github.com/BVLC/caffe/pull/1950#issuecomment-76026969)): ```sh -cat $CAFFE_HOME/python/requirements.txt | xargs -n1 sudo pip install +cat $CAFFE_ROOT/python/requirements.txt | xargs -n1 sudo pip install ``` ## Build We recommend using CMake to configure Caffe rather than the raw Makefile build for automatic dependency detection: ```sh -cd $CAFFE_HOME +cd $CAFFE_ROOT mkdir build cd build cmake .. diff --git a/docs/BuildDigits.md b/docs/BuildDigits.md index e115a3997..c7e5b8b5d 100644 --- a/docs/BuildDigits.md +++ b/docs/BuildDigits.md @@ -27,17 +27,17 @@ Follow [these instructions](BuildTorch.md) to build Torch7 (*suggested*). ```sh # example location - can be customized -DIGITS_HOME=~/digits -git clone https://github.com/NVIDIA/DIGITS.git $DIGITS_HOME +DIGITS_ROOT=~/digits +git clone https://github.com/NVIDIA/DIGITS.git $DIGITS_ROOT ``` -Throughout the docs, we'll refer to your install location as `DIGITS_HOME` (`~/digits` in this case), though you don't need to actually set that environment variable. +Throughout the docs, we'll refer to your install location as `DIGITS_ROOT` (`~/digits` in this case), though you don't need to actually set that environment variable. ## Python packages Several PyPI packages need to be installed: ```sh -sudo pip install -r $DIGITS_HOME/requirements.txt +sudo pip install -r $DIGITS_ROOT/requirements.txt ``` # Starting the server diff --git a/docs/BuildTorch.md b/docs/BuildTorch.md index 624af986a..187c6ab28 100644 --- a/docs/BuildTorch.md +++ b/docs/BuildTorch.md @@ -21,11 +21,10 @@ sudo apt-get install --no-install-recommends git software-properties-common These instructions are based on [the official Torch instructions](http://torch.ch/docs/getting-started.html). ```sh # example location - can be customized -export TORCH_BUILD=~/torch -export TORCH_HOME=$TORCH_BUILD/install +export TORCH_ROOT=~/torch -git clone https://github.com/torch/distro.git $TORCH_BUILD --recursive -cd $TORCH_BUILD +git clone https://github.com/torch/distro.git $TORCH_ROOT --recursive +cd $TORCH_ROOT ./install-deps ./install.sh -b source ~/.bashrc diff --git a/docs/Configuration.md b/docs/Configuration.md new file mode 100644 index 000000000..d259c7d42 --- /dev/null +++ b/docs/Configuration.md @@ -0,0 +1,19 @@ +# Configuration + +DIGITS uses environment variables for configuration. +The code for reading these variables and setting the configuration are at [digits/config/](../digits/config/). + +> NOTE: Prior to https://github.com/NVIDIA/DIGITS/pull/1091 (up to DIGITS 4.0), DIGITS used configuration files instead. + + +## Environment Variables + +| Variable | Example value | Description | +| --- | --- | --- | +| `DIGITS_JOBS_DIR` | ~/digits-jobs | Location where job files are stored. Default is `$DIGITS_ROOT/digits/jobs`. | +| `CAFFE_ROOT` | ~/caffe | Path to your local Caffe build. Should contain `build/tools/caffe` and `python/caffe/`. If unset, looks for caffe in PATH and PYTHONPATH.| +| `TORCH_ROOT` | ~/torch | Path to your local Torch build. Should contain `install/bin/th`. If unset, looks for th in PATH. | +| `DIGITS_ALL_EXTENSIONS` | 1 | Show all extensions, even those that are hidden by default. | +| `DIGITS_LOGFILE_FILENAME` | ~/digits.log | File for saving log messages. Default is `$DIGITS_ROOT/digits/digits.log`. | +| `DIGITS_LOGFILE_LEVEL` | DEBUG | Minimum log message level to be saved (DEBUG/INFO/WARNING/ERROR/CRITICAL). Default is INFO. | +| `DIGITS_SERVER_NAME` | The Big One | The name of the server (accessible in the UI under "Info"). Default is the system hostname. | diff --git a/docs/GettingStarted.md b/docs/GettingStarted.md index 2b3c79512..54391996a 100644 --- a/docs/GettingStarted.md +++ b/docs/GettingStarted.md @@ -10,7 +10,7 @@ Both are generously made available by Yann LeCun on [his website](http://yann.le Use the following command to download the MNIST dataset onto your server (for Deb package installations, the script is at `/usr/share/digits/tools/download_data/main.py`): ```sh -$ $DIGITS_HOME/tools/download_data/main.py mnist ~/mnist +$ $DIGITS_ROOT/tools/download_data/main.py mnist ~/mnist Downloading url=http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz ... Downloading url=http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz ... Downloading url=http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz ... diff --git a/examples/fine-tuning/README.md b/examples/fine-tuning/README.md index bb0bdcd4b..6ae1aeca2 100644 --- a/examples/fine-tuning/README.md +++ b/examples/fine-tuning/README.md @@ -26,7 +26,7 @@ We will need to create a new dataset that comprises two classes: one for images Assuming you have a folder containing MNIST images, you may use the `create_dataset.sh` script to create the required directory structure into a folder called `odd_or_even_dataset`: ```sh -$ cd $DIGITS_HOME/examples/fine-tuning +$ cd $DIGITS_ROOT/examples/fine-tuning $ ./create_dataset.sh odd_or_even_dataset /train ``` diff --git a/examples/object-detection/README.md b/examples/object-detection/README.md index db9dc04bf..aa9d17595 100644 --- a/examples/object-detection/README.md +++ b/examples/object-detection/README.md @@ -35,7 +35,7 @@ Left color images of object data set | `data_object_image_2.zip` | **12GB** Training labels of object data set | `data_object_label_2.zip` | 5MB Object development kit | `devkit_object.zip` | 1MB -Copy those files into `$DIGITS_HOME/examples/object-detection/`. +Copy those files into `$DIGITS_ROOT/examples/object-detection/`. Then, use the `prepare_kitti_data.py` script to create a train/val split of the labelled images. This will take a few minutes, spent mostly on unpacking the large zipfiles. @@ -51,7 +51,7 @@ Creating train/val split ... Done. ``` -At the end you will have your data at `$DIGITS_HOME/examples/object-detection/kitti-data/{train,val}/`. +At the end you will have your data at `$DIGITS_ROOT/examples/object-detection/kitti-data/{train,val}/`. The data is structured in the following way: - An image folder containing supported images (`.png`, `.jpg`, etc.). @@ -86,7 +86,7 @@ DetectNet is a GoogLeNet-derived network that is specifically tuned for Object D For more information on DetectNet, please refer to [this blog post](https://devblogs.nvidia.com/parallelforall/detectnet-deep-neural-network-object-detection-digits/). In order to train DetectNet, [NVcaffe](https://github.com/NVIDIA/caffe) version [0.15.1](https://github.com/NVIDIA/caffe/tree/v0.15.1) or later is required. -The model description for DetectNet can be found at `$CAFFE_HOME/examples/kitti/detectnet_network.prototxt` ([raw link](https://raw.githubusercontent.com/NVIDIA/caffe/caffe-0.15/examples/kitti/detectnet_network.prototxt)). +The model description for DetectNet can be found at `$CAFFE_ROOT/examples/kitti/detectnet_network.prototxt` ([raw link](https://raw.githubusercontent.com/NVIDIA/caffe/caffe-0.15/examples/kitti/detectnet_network.prototxt)). Since DetectNet is derived from GoogLeNet it is strongly recommended to use pre-trained weights from an ImageNet-trained GoogLeNet as this will help speed training up significantly. A suitable pre-trained GoogLeNet `.caffemodel` may be found on this [page](https://github.com/BVLC/caffe/tree/rc3/models/bvlc_googlenet). diff --git a/examples/semantic-segmentation/README.md b/examples/semantic-segmentation/README.md index c25dd463a..5e3b28b7d 100644 --- a/examples/semantic-segmentation/README.md +++ b/examples/semantic-segmentation/README.md @@ -96,7 +96,7 @@ On the DIGITS home page, click `New Dataset > Images > Segmentation`: In the dataset creation form, click `Separate validation images` then specify the paths to the image and label folders for each of the training and validation sets. -In `Class Labels` specify the path to `$DIGITS_HOME/examples/semantic-segmentation/pascal-voc-classes.txt`. +In `Class Labels` specify the path to `$DIGITS_ROOT/examples/semantic-segmentation/pascal-voc-classes.txt`. This will allow DIGITS to print class names during inference. In `Label Encoding` select `PNG (lossless)`. diff --git a/examples/siamese/README.md b/examples/siamese/README.md index f8365da4d..2237b4f55 100644 --- a/examples/siamese/README.md +++ b/examples/siamese/README.md @@ -38,7 +38,7 @@ The first step in creating the dataset is to create the LMDB databases. In this To create a train database of 100000 pairs of images into a folder called `siamesedb`: ```sh -$ cd $DIGITS_HOME/examples/siamese +$ cd $DIGITS_ROOT/examples/siamese $ create_db.py siamesedb ../../digits/jobs/20151111-210842-a4ec/train.txt -c 100000 ``` The script also creates a validation database of 1000 samples. Overall, the script creates: diff --git a/examples/siamese/create_db.py b/examples/siamese/create_db.py index a611cf0da..a072f42ba 100755 --- a/examples/siamese/create_db.py +++ b/examples/siamese/create_db.py @@ -26,12 +26,11 @@ if __name__ == '__main__': dirname = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, os.path.join(dirname,'..','..')) - from digits.config.load import load_config - load_config() + import digits.config from digits import utils -# Run load_config() first to set the path to Caffe +# Import digits.config first to set the path to Caffe import caffe.io import caffe_pb2 diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index c164a53bf..d9d8f3d85 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -29,7 +29,7 @@ The following sample is an example from the "company" class: The first step to creating the dataset is to convert the `.csv` files to a format that DIGITS can use: ```sh -$ cd $DIGITS_HOME/examples/text-classification +$ cd $DIGITS_ROOT/examples/text-classification $ ./create_dataset.py $DBPEDIA/dbpedia_csv/train.csv dbpedia/train --labels $DBPEDIA/dbpedia_csv/classes.txt --create-images ``` diff --git a/gunicorn_config.py b/gunicorn_config.py index de41839d8..1bbcf4b7d 100644 --- a/gunicorn_config.py +++ b/gunicorn_config.py @@ -10,10 +10,6 @@ bind = '0.0.0.0:34448' # DIGIT loglevel = 'debug' -def on_starting(server): - from digits import config - config.load_config() - def post_fork(server, worker): from digits.webapp import scheduler scheduler.start() diff --git a/tools/analyze_db.py b/tools/analyze_db.py index 32cd3ae6e..442438aab 100755 --- a/tools/analyze_db.py +++ b/tools/analyze_db.py @@ -21,10 +21,9 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import log -# Run load_config() first to set path to Caffe +# Import digits.config first to set path to Caffe import caffe_pb2 logger = logging.getLogger('digits.tools.analyze_db') diff --git a/tools/create_db.py b/tools/create_db.py index 83ee50a19..50cfd0b4e 100755 --- a/tools/create_db.py +++ b/tools/create_db.py @@ -28,10 +28,9 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import utils, log -# must call digits.config.load_config() before caffe to set the path +# Import digits.config first to set the path to Caffe import caffe.io import caffe_pb2 diff --git a/tools/create_generic_db.py b/tools/create_generic_db.py index 12ee67a16..bd0ec0fec 100755 --- a/tools/create_generic_db.py +++ b/tools/create_generic_db.py @@ -19,11 +19,10 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import extensions, log from digits.job import Job -# Run load_config() first to set the path to Caffe +# Import digits.config first to set the path to Caffe import caffe.io import caffe_pb2 diff --git a/tools/inference.py b/tools/inference.py index a008b51f7..a6cb19e4b 100755 --- a/tools/inference.py +++ b/tools/inference.py @@ -17,13 +17,12 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import utils, log from digits.inference.errors import InferenceError from digits.job import Job from digits.utils.lmdbreader import DbReader -# must call digits.config.load_config() before caffe to set the path +# Import digits.config before caffe to set the path import caffe.io import caffe_pb2 diff --git a/tools/parse_folder.py b/tools/parse_folder.py index 89fc8ea6a..b129eec6e 100755 --- a/tools/parse_folder.py +++ b/tools/parse_folder.py @@ -14,7 +14,6 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import utils, log logger = logging.getLogger('digits.tools.parse_folder') diff --git a/tools/resize_image.py b/tools/resize_image.py index 690943ad4..0b9d6851a 100755 --- a/tools/resize_image.py +++ b/tools/resize_image.py @@ -11,7 +11,6 @@ # Add path for DIGITS package sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import digits.config -digits.config.load_config() from digits import utils, log logger = logging.getLogger('digits.tools.resize_image')