From 1424609de947f7ffd865728f95d03fb30d519737 Mon Sep 17 00:00:00 2001 From: chenglu Date: Sun, 25 Oct 2020 13:31:45 +0800 Subject: [PATCH 1/4] Add geting help message from docstring --- pytorch_lightning/utilities/argparse_utils.py | 29 +++++++++-- tests/utilities/test_argparse_utils.py | 50 +++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 tests/utilities/test_argparse_utils.py diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 57c9e23d80dc9..d310bfebedbf2 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -14,7 +14,7 @@ import inspect import os from argparse import ArgumentParser, Namespace -from typing import Union, List, Tuple, Any +from typing import Dict, Union, List, Tuple, Any from pytorch_lightning.utilities import parsing @@ -160,7 +160,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: allowed_types = (str, int, float, bool) - # TODO: get "help" from docstring :) + args_help = parse_args_from_docstring(cls.__doc__ or cls.__init__.__doc__) for arg, arg_types, arg_default in ( at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names ): @@ -200,13 +200,36 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: dest=arg, default=arg_default, type=use_type, - help='autogenerated by pl.Trainer', + help=args_help.get(arg), **arg_kwargs, ) return parser +def parse_args_from_docstring(docstring: str) -> Dict[str, str]: + arg_block_indent = None + current_arg = None + parsed = {} + for line in docstring.split("\n"): + stripped = line.lstrip() + if not stripped: + continue + line_indent = len(line) - len(stripped) + if stripped.startswith('Args:'): + arg_block_indent = line_indent + 4 + elif arg_block_indent is None: + continue + elif line_indent < arg_block_indent: + break + elif line_indent == arg_block_indent: + current_arg, arg_description = stripped.split(':', maxsplit=1) + parsed[current_arg] = arg_description.lstrip() + elif line_indent > arg_block_indent: + parsed[current_arg] = f'{parsed[current_arg]} {stripped}' + return parsed + + def _gpus_allowed_type(x) -> Union[int, str]: if ',' in x: return str(x) diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse_utils.py new file mode 100644 index 0000000000000..066dad6bddb6a --- /dev/null +++ b/tests/utilities/test_argparse_utils.py @@ -0,0 +1,50 @@ +from pytorch_lightning.utilities.argparse_utils import parse_args_from_docstring + + +def test_parse_args_from_docstring_normal(): + args_help = parse_args_from_docstring( + """Constrain image dataset + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + num_samples: number of examples per selected class/digit + digits: list selected MNIST digits/classes + + Examples: + >>> dataset = TrialMNIST(download=True) + >>> len(dataset) + 300 + >>> sorted(set([d.item() for d in dataset.targets])) + [0, 1, 2] + >>> torch.bincount(dataset.targets) + tensor([100, 100, 100]) + """ + ) + + expected_args = ['root', 'train', 'normalize', 'download', 'num_samples', 'digits'] + assert len(args_help.keys()) == len(expected_args) + assert all([x == y for x, y in zip(args_help.keys(), expected_args)]) + assert args_help['root'] == 'Root directory of dataset where ``MNIST/processed/training.pt``' \ + ' and ``MNIST/processed/test.pt`` exist.' + assert args_help['normalize'] == 'mean and std deviation of the MNIST dataset.' + + +def test_parse_args_from_docstring_empty(): + args_help = parse_args_from_docstring( + """Constrain image dataset + + Args: + + Returns: + + Examples: + """ + ) + assert len(args_help.keys()) == 0 From 280681193f90d53497283f75b69ad49800f0e4a2 Mon Sep 17 00:00:00 2001 From: chenglu Date: Sun, 25 Oct 2020 14:27:26 +0800 Subject: [PATCH 2/4] Fix pep8 issue --- tests/utilities/test_argparse_utils.py | 56 +++++++++++++------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/utilities/test_argparse_utils.py b/tests/utilities/test_argparse_utils.py index 066dad6bddb6a..978ad820482b2 100644 --- a/tests/utilities/test_argparse_utils.py +++ b/tests/utilities/test_argparse_utils.py @@ -3,29 +3,29 @@ def test_parse_args_from_docstring_normal(): args_help = parse_args_from_docstring( - """Constrain image dataset - - Args: - root: Root directory of dataset where ``MNIST/processed/training.pt`` - and ``MNIST/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - normalize: mean and std deviation of the MNIST dataset. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - num_samples: number of examples per selected class/digit - digits: list selected MNIST digits/classes - - Examples: - >>> dataset = TrialMNIST(download=True) - >>> len(dataset) - 300 - >>> sorted(set([d.item() for d in dataset.targets])) - [0, 1, 2] - >>> torch.bincount(dataset.targets) - tensor([100, 100, 100]) - """ + """Constrain image dataset + + Args: + root: Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + normalize: mean and std deviation of the MNIST dataset. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + num_samples: number of examples per selected class/digit + digits: list selected MNIST digits/classes + + Examples: + >>> dataset = TrialMNIST(download=True) + >>> len(dataset) + 300 + >>> sorted(set([d.item() for d in dataset.targets])) + [0, 1, 2] + >>> torch.bincount(dataset.targets) + tensor([100, 100, 100]) + """ ) expected_args = ['root', 'train', 'normalize', 'download', 'num_samples', 'digits'] @@ -38,13 +38,13 @@ def test_parse_args_from_docstring_normal(): def test_parse_args_from_docstring_empty(): args_help = parse_args_from_docstring( - """Constrain image dataset + """Constrain image dataset - Args: + Args: - Returns: + Returns: - Examples: - """ + Examples: + """ ) assert len(args_help.keys()) == 0 From 08ebfea084170ae586dc9f08c0e77f5cc980b18e Mon Sep 17 00:00:00 2001 From: Chenglu Date: Sun, 25 Oct 2020 17:02:10 +0800 Subject: [PATCH 3/4] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/argparse_utils.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a534c6bfaf40..08e2e93b93d9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162)) +- Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344)) + + ### Changed diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index d310bfebedbf2..51b4f1c888928 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -216,7 +216,7 @@ def parse_args_from_docstring(docstring: str) -> Dict[str, str]: if not stripped: continue line_indent = len(line) - len(stripped) - if stripped.startswith('Args:'): + if stripped.startswith(('Args:', 'Arguments:', 'Parameters:')): arg_block_indent = line_indent + 4 elif arg_block_indent is None: continue From d3a033a409e61b9ef728aec05a23f4f42e3bced5 Mon Sep 17 00:00:00 2001 From: Chenglu Date: Sun, 25 Oct 2020 18:05:07 +0800 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pytorch_lightning/utilities/argparse_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 51b4f1c888928..f3cf2e5f1b90d 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -160,7 +160,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: allowed_types = (str, int, float, bool) - args_help = parse_args_from_docstring(cls.__doc__ or cls.__init__.__doc__) + args_help = parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__) for arg, arg_types, arg_default in ( at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names ): @@ -226,7 +226,7 @@ def parse_args_from_docstring(docstring: str) -> Dict[str, str]: current_arg, arg_description = stripped.split(':', maxsplit=1) parsed[current_arg] = arg_description.lstrip() elif line_indent > arg_block_indent: - parsed[current_arg] = f'{parsed[current_arg]} {stripped}' + parsed[current_arg] += f' {stripped}' return parsed