Skip to content

Commit

Permalink
ref: separate argparse (#3428)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Sep 9, 2020
1 parent f7dac3f commit 9696484
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 106 deletions.
107 changes: 1 addition & 106 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,112 +488,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
Examples:
>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)

blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, int, float, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in argparse_utils.get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg == 'gpus' or arg == 'tpu_cores':
use_type = Trainer._gpus_allowed_type
arg_default = Trainer._gpus_arg_default

# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = Trainer._int_or_float_type

# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)

return parser

def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)

def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)

def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)
return argparse_utils.add_argparse_args(cls, parent_parser)

@property
def num_gpus(self) -> int:
Expand Down
114 changes: 114 additions & 0 deletions pytorch_lightning/utilities/argparse_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
Expand Down Expand Up @@ -107,3 +108,116 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
name_type_default.append((arg, arg_types, arg_default))

return name_type_default


def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
Examples:
>>> import argparse
>>> import pprint
>>> from pytorch_lightning import Trainer
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)

blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, int, float, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]

if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default

# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = _int_or_float_type

# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float

parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)

return parser


def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)


def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)


def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)

0 comments on commit 9696484

Please sign in to comment.