diff --git a/autoflake.py b/autoflake.py index 77a5916..70a5aa5 100755 --- a/autoflake.py +++ b/autoflake.py @@ -331,7 +331,9 @@ def filter_code(source, additional_imports=None, expand_star_imports=False, remove_all_unused_imports=False, remove_duplicate_keys=False, - remove_unused_variables=False): + remove_unused_variables=False, + ignore_init_module_imports=False, + ): """Yield code with unused imports removed.""" imports = SAFE_IMPORTS if additional_imports: @@ -340,8 +342,11 @@ def filter_code(source, additional_imports=None, messages = check(source) - marked_import_line_numbers = frozenset( - unused_import_line_numbers(messages)) + if ignore_init_module_imports: + marked_import_line_numbers = frozenset() + else: + marked_import_line_numbers = frozenset( + unused_import_line_numbers(messages)) marked_unused_module = collections.defaultdict(lambda: []) for line_number, module_name in unused_import_module_name(messages): marked_unused_module[line_number].append(module_name) @@ -589,7 +594,7 @@ def get_line_ending(line): def fix_code(source, additional_imports=None, expand_star_imports=False, remove_all_unused_imports=False, remove_duplicate_keys=False, - remove_unused_variables=False): + remove_unused_variables=False, ignore_init_module_imports=False): """Return code with all filtering run on it.""" if not source: return source @@ -608,7 +613,9 @@ def fix_code(source, additional_imports=None, expand_star_imports=False, expand_star_imports=expand_star_imports, remove_all_unused_imports=remove_all_unused_imports, remove_duplicate_keys=remove_duplicate_keys, - remove_unused_variables=remove_unused_variables)))) + remove_unused_variables=remove_unused_variables, + ignore_init_module_imports=ignore_init_module_imports, + )))) if filtered_source == source: break @@ -625,13 +632,20 @@ def fix_file(filename, args, standard_out): original_source = source + if args.ignore_init_module_imports and filename.endswith('__init__.py'): + ignore_init_module_imports = True + else: + ignore_init_module_imports = False + filtered_source = fix_code( source, additional_imports=args.imports.split(',') if args.imports else None, expand_star_imports=args.expand_star_imports, remove_all_unused_imports=args.remove_all_unused_imports, remove_duplicate_keys=args.remove_duplicate_keys, - remove_unused_variables=args.remove_unused_variables) + remove_unused_variables=args.remove_unused_variables, + ignore_init_module_imports=ignore_init_module_imports, + ) if original_source != filtered_source: if args.in_place: @@ -791,6 +805,9 @@ def _main(argv, standard_out, standard_error): parser.add_argument('--remove-all-unused-imports', action='store_true', help='remove all unused imports (not just those from ' 'the standard library)') + parser.add_argument('--ignore-init-module-imports', action='store_true', + help='exclude __init__.py when removing unused ' + 'imports') parser.add_argument('--remove-duplicate-keys', action='store_true', help='remove all duplicate keys in objects') parser.add_argument('--remove-unused-variables', action='store_true', diff --git a/test_autoflake.py b/test_autoflake.py index 919a81a..8cdd710 100755 --- a/test_autoflake.py +++ b/test_autoflake.py @@ -501,6 +501,41 @@ def foo(): """ self.assertEqual(line, ''.join(autoflake.filter_code(line))) + def test_with_ignore_init_module_imports_flag(self): + # Need a temp directory in order to specify file name as __init__.py + temp_directory = tempfile.mkdtemp(dir='.') + temp_file = os.path.join(temp_directory, '__init__.py') + try: + with open(temp_file, 'w') as output: + output.write('import re\n') + + p = subprocess.Popen( + list(AUTOFLAKE_COMMAND) + + ['--ignore-init-module-imports', temp_file], + stdout=subprocess.PIPE) + result = p.communicate()[0].decode('utf-8') + + self.assertNotIn('import re', result) + finally: + shutil.rmtree(temp_directory) + + def test_without_ignore_init_module_imports_flag(self): + # Need a temp directory in order to specify file name as __init__.py + temp_directory = tempfile.mkdtemp(dir='.') + temp_file = os.path.join(temp_directory, '__init__.py') + try: + with open(temp_file, 'w') as output: + output.write('import re\n') + + p = subprocess.Popen( + list(AUTOFLAKE_COMMAND) + [temp_file], + stdout=subprocess.PIPE) + result = p.communicate()[0].decode('utf-8') + + self.assertIn('import re', result) + finally: + shutil.rmtree(temp_directory) + def test_fix_code(self): self.assertEqual( """\