diff --git a/ansible_runner/__main__.py b/ansible_runner/__main__.py index 228d5b241..ff9925df3 100644 --- a/ansible_runner/__main__.py +++ b/ansible_runner/__main__.py @@ -432,6 +432,15 @@ logger = logging.getLogger('ansible-runner') +class AnsibleRunnerArgumentParser(argparse.ArgumentParser): + def error(self, message): + # If no sub command was provided, print common usage then exit + if 'required: command' in message.lower(): + print_common_usage() + + super(AnsibleRunnerArgumentParser, self).error(message) + + @contextmanager def role_manager(vargs): if vargs.get('role'): @@ -583,7 +592,7 @@ def main(sys_args=None): :rtype: SystemExit """ - parser = argparse.ArgumentParser( + parser = AnsibleRunnerArgumentParser( prog='ansible-runner', description="Use 'ansible-runner' (with no arguments) to see basic usage" ) @@ -747,11 +756,6 @@ def main(sys_args=None): add_args_to_parser(isalive_container_group, DEFAULT_CLI_ARGS['container_group']) add_args_to_parser(transmit_container_group, DEFAULT_CLI_ARGS['container_group']) - if len(sys.argv) == 1: - parser.print_usage() - print_common_usage() - parser.exit(status=0) - args = parser.parse_args(sys_args) vargs = vars(args) diff --git a/test/integration/test_main.py b/test/integration/test_main.py index 34cd96439..5f0aea353 100644 --- a/test/integration/test_main.py +++ b/test/integration/test_main.py @@ -70,10 +70,26 @@ def will_pass(): assert not os.path.exists(context['saved_temp_dir']) -def test_help(): +@pytest.mark.parametrize( + ('command', 'expected'), + ( + (None, {'out': 'These are common Ansible Runner commands', 'err': ''}), + ([], {'out': 'These are common Ansible Runner commands', 'err': ''}), + (['run'], {'out': '', 'err': 'the following arguments are required'}), + ) +) +def test_help(command, expected, capsys, monkeypatch): + # Ensure that sys.argv of the test command does not affect the test environment. + monkeypatch.setattr('sys.argv', command or []) + with pytest.raises(SystemExit) as exc: - main([]) + main(command) + + stdout, stderr = capsys.readouterr() + assert exc.value.code == 2, 'Should raise SystemExit with return code 2' + assert expected['out'] in stdout + assert expected['err'] in stderr def test_module_run():