Skip to content

Commit

Permalink
Only remove all imports if told to do so
Browse files Browse the repository at this point in the history
  • Loading branch information
myint committed Apr 19, 2017
1 parent 0774950 commit ce723cb
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 16 deletions.
36 changes: 20 additions & 16 deletions autoflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,27 @@ def filter_unused_import(line, unused_module, remove_all_unused_imports,
"""Return line if used, otherwise return None."""
if multiline_import(line, previous_line):
return line
elif ',' in line:
if re.match(r'^\s*from\s', line):
return filter_from_import(line, unused_module)
else:
return break_up_import(line)

is_from_import = re.match(r'^\s*from\s', line)

if ',' in line and not is_from_import:
return break_up_import(line)

package = extract_package_name(line)
if not remove_all_unused_imports and package not in imports:
return line

if ',' in line:
assert is_from_import
return filter_from_import(line, unused_module)
else:
package = extract_package_name(line)
if not remove_all_unused_imports and package not in imports:
return line
else:
# We need to replace import with "pass" in case the import is the
# only line inside a block. For example,
# "if True:\n import os". In such cases, if the import is
# removed, the block will be left hanging with no body.
return (get_indentation(line) +
'pass' +
get_line_ending(line))
# We need to replace import with "pass" in case the import is the
# only line inside a block. For example,
# "if True:\n import os". In such cases, if the import is
# removed, the block will be left hanging with no body.
return (get_indentation(line) +
'pass' +
get_line_ending(line))


def filter_unused_variable(line, previous_line=''):
Expand Down
14 changes: 14 additions & 0 deletions test_autoflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,20 @@ def test_fix_code_with_from_and_as(self):
from collections import defaultdict as abc, namedtuple as xyz
"""))

def test_fix_code_with_from_with_and_without_remove_all(self):
code = """\
from x import a as b, c as d
"""

self.assertEqual(
"""\
""",
autoflake.fix_code(code, remove_all_unused_imports=True))

self.assertEqual(
code,
autoflake.fix_code(code, remove_all_unused_imports=False))

def test_fix_code_with_from_and_depth_module(self):
self.assertEqual(
"""\
Expand Down

0 comments on commit ce723cb

Please sign in to comment.