Skip to content

Commit

Permalink
fix gpus default for Trainer.add_argparse_args (#6898)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Apr 9, 2021
1 parent aaccbee commit 9c9e2a0
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))


## [1.2.7] - 2021-04-06

### Fixed
Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def add_argparse_args(

if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default

# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
Expand Down Expand Up @@ -287,10 +286,6 @@ def _gpus_allowed_type(x) -> Union[int, str]:
return int(x)


def _gpus_arg_default(x) -> Union[int, str]:
return _gpus_allowed_type(x)


def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
Expand Down
9 changes: 0 additions & 9 deletions pytorch_lightning/utilities/device_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
If no GPUs are available but the value of gpus variable indicates request for GPUs
then a MisconfigurationException is raised.
"""

# nothing was passed into the GPUs argument
if callable(gpus):
return None

# Check that gpus param is None, Int, String or List
_check_data_type(gpus)

Expand Down Expand Up @@ -97,10 +92,6 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
Returns:
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
"""

if callable(tpu_cores):
return None

_check_data_type(tpu_cores)

if isinstance(tpu_cores, str):
Expand Down
4 changes: 1 addition & 3 deletions tests/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
import os
import pickle
import types
from argparse import ArgumentParser
from unittest import mock

Expand Down Expand Up @@ -172,11 +171,10 @@ def wrapper_something():
params.wrapper_something_wo_name = lambda: lambda: '1'
params.wrapper_something = wrapper_something

assert isinstance(params.gpus, types.FunctionType)
params = WandbLogger._convert_params(params)
params = WandbLogger._flatten_dict(params)
params = WandbLogger._sanitize_callable_params(params)
assert params["gpus"] == '_gpus_arg_default'
assert params["gpus"] == "None"
assert params["something"] == "something"
assert params["wrapper_something"] == "wrapper_something"
assert params["wrapper_something_wo_name"] == "<lambda>"
Expand Down
12 changes: 7 additions & 5 deletions tests/trainer/test_trainer_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,21 +175,23 @@ def test_argparse_args_parsing(cli_args, expected):
assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
pytest.param('--gpus 1', [0]),
pytest.param('--gpus 0,', [0]),
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
pytest.param('', None, None),
pytest.param('--gpus 1', 1, [0]),
pytest.param('--gpus 0,', '0,', [0]),
])
@RunIf(min_gpus=1)
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
def test_argparse_args_parsing_gpus(cli_args, expected_parsed, expected_device_ids):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)

assert args.gpus == expected_parsed
trainer = Trainer.from_argparse_args(args)
assert trainer.data_parallel_device_ids == expected_gpu
assert trainer.data_parallel_device_ids == expected_device_ids


@RunIf(min_python="3.7.0")
Expand Down
8 changes: 4 additions & 4 deletions tests/utilities/test_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import (
_gpus_arg_default,
_gpus_allowed_type,
_int_or_float_type,
add_argparse_args,
from_argparse_args,
Expand Down Expand Up @@ -205,9 +205,9 @@ def test_add_argparse_args_no_argument_group():
assert args.my_parameter == 2


def test_gpus_arg_default():
assert _gpus_arg_default('1,2') == '1,2'
assert _gpus_arg_default('1') == 1
def test_gpus_allowed_type():
assert _gpus_allowed_type('1,2') == '1,2'
assert _gpus_allowed_type('1') == 1


def test_int_or_float_type():
Expand Down

0 comments on commit 9c9e2a0

Please sign in to comment.