Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiline import statements (proposal for issue #8) #63

Merged
merged 13 commits into from
Aug 23, 2020
Merged
225 changes: 202 additions & 23 deletions autoflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import os
import re
import signal
import string
import sys
import tokenize

Expand Down Expand Up @@ -250,10 +251,6 @@ def multiline_import(line, previous_line=''):
if symbol in line:
return True

# Ignore doctests.
if line.lstrip().startswith('>'):
return True

return multiline_statement(line, previous_line)


Expand All @@ -271,6 +268,182 @@ def multiline_statement(line, previous_line=''):
return True


class PendingFix(object):
"""Allows a rewrite operation to span multiple lines.

In the main rewrite loop, every time a helper function returns a
``PendingFix`` object instead of a string, this object will be called
with the following line.
"""

def __init__(self, line):
"""Analyse and store the first line."""
self.accumulator = collections.deque([line])

def __call__(self, line):
"""Process line considering the accumulator.

Return self to keep processing the following lines or a string
with the final result of all the lines processed at once.
"""
raise NotImplementedError("Abstract method needs to be overwritten")


def _valid_char_in_line(char, line):
"""Return True if a char appears in the line and is not commented."""
comment_index = line.find('#')
char_index = line.find(char)
valid_char_in_line = (
char_index >= 0 and
(comment_index > char_index or comment_index < 0)
)
return valid_char_in_line


def _top_module(module_name):
"""Return the name of the top level module in the hierarchy."""
if module_name[0] == '.':
return '%LOCAL_MODULE%'
return module_name.split('.')[0]


def _modules_to_remove(unused_modules, safe_to_remove=SAFE_IMPORTS):
"""Discard unused modules that are not safe to remove from the list."""
return [x for x in unused_modules if _top_module(x) in safe_to_remove]


def _segment_module(segment):
"""Extract the module identifier inside the segment.

It might be the case the segment does not have a module (e.g. is composed
just by a parenthesis or line continuation and whitespace). In this
scenario we just keep the segment... These characters are not valid in
identifiers, so they will never be contained in the list of unused modules
anyway.
"""
return segment.strip(string.whitespace + ',\\()') or segment


class FilterMultilineImport(PendingFix):
"""Remove unused imports from multiline import statements.

This class handles both the cases: "from imports" and "direct imports".

Some limitations exist (e.g. imports with comments, lines joined by ``;``,
etc). In these cases, the statement is left unchanged to avoid problems.
"""

IMPORT_RE = re.compile(r'\bimport\b\s*')
INDENTATION_RE = re.compile(r'^\s*')
BASE_RE = re.compile(r'\bfrom\s+([^ ]+)')
SEGMENT_RE = re.compile(
r'([^,\s]+(?:[\s\\]+as[\s\\]+[^,\s]+)?[,\s\\)]*)', re.M)
# ^ module + comma + following space (including new line and continuation)
IDENTIFIER_RE = re.compile(r'[^,\s]+')

def __init__(self, line, unused_module=(), remove_all_unused_imports=False,
safe_to_remove=SAFE_IMPORTS, previous_line=''):
"""Receive the same parameters as ``filter_unused_import``."""
self.remove = unused_module
self.parenthesized = '(' in line
self.from_, imports = self.IMPORT_RE.split(line, maxsplit=1)
match = self.BASE_RE.search(self.from_)
self.base = match.group(1) if match else None
self.give_up = False

if not remove_all_unused_imports:
if self.base and _top_module(self.base) not in safe_to_remove:
self.give_up = True
else:
self.remove = _modules_to_remove(self.remove, safe_to_remove)

if '\\' in previous_line:
# Ignore tricky things like "try: \<new line> import" ...
self.give_up = True

self.analyze(line)

PendingFix.__init__(self, imports)

def is_over(self, line=None):
"""Return True if the multiline import statement is over."""
line = line or self.accumulator[-1]

if self.parenthesized:
return _valid_char_in_line(')', line)

return not _valid_char_in_line('\\', line)

def analyze(self, line):
"""Decide if the statement will be fixed or left unchanged."""
if any(ch in line for ch in ';:#'):
self.give_up = True

def fix(self, accumulated):
"""Given a collection of accumulated lines, fix the entire import."""
old_imports = ''.join(accumulated)
ending = get_line_ending(old_imports)
# Split imports into segments that contain the module name +
# comma + whitespace and eventual <newline> \ ( ) chars
segments = [x for x in self.SEGMENT_RE.findall(old_imports) if x]
modules = [_segment_module(x) for x in segments]
keep = _filter_imports(modules, self.base, self.remove)

# Short-circuit if no import was discarded
if len(keep) == len(segments):
return self.from_ + 'import ' + ''.join(accumulated)

fixed = ''
if keep:
# Since it is very difficult to deal with all the line breaks and
# continuations, let's use the code layout that already exists and
# just replace the module identifiers inside the first N-1 segments
# + the last segment
templates = list(zip(modules, segments))
templates = templates[:len(keep)-1] + templates[-1:]
# It is important to keep the last segment, since it might contain
# important chars like `)`
fixed = ''.join(
template.replace(module, keep[i])
for i, (module, template) in enumerate(templates)
)

# Fix the edge case: inline parenthesis + just one surviving import
if self.parenthesized and any(ch not in fixed for ch in '()'):
fixed = fixed.strip(string.whitespace + '()') + ending

# Replace empty imports with a "pass" statement
empty = len(fixed.strip(string.whitespace + '\\(),')) < 1
if empty:
indentation = self.INDENTATION_RE.search(self.from_).group(0)
return indentation + 'pass' + ending

return self.from_ + 'import ' + fixed

def __call__(self, line=None):
"""Accumulate all the lines in the import and then trigger the fix."""
if line:
self.accumulator.append(line)
self.analyze(line)
if not self.is_over(line):
return self
if self.give_up:
return self.from_ + 'import ' + ''.join(self.accumulator)

return self.fix(self.accumulator)


def _filter_imports(imports, parent=None, unused_module=()):
# We compare full module name (``a.module`` not `module`) to
# guarantee the exact same module as detected from pyflakes.
sep = '' if parent and parent[-1] == '.' else '.'

def full_name(name):
return name if parent is None else parent + sep + name

return [x for x in imports if full_name(x) not in unused_module]


def filter_from_import(line, unused_module):
"""Parse and filter ``from something import a, b, c``.

Expand All @@ -282,15 +455,8 @@ def filter_from_import(line, unused_module):
base_module = re.search(pattern=r'\bfrom\s+([^ ]+)',
string=indentation).group(1)

# Create an imported module list with base module name
# ex ``from a import b, c as d`` -> ``['a.b', 'a.c as d']``
imports = re.split(pattern=r',', string=imports.strip())
imports = [base_module + '.' + x.strip() for x in imports]

# We compare full module name (``a.module`` not `module`) to
# guarantee the exact same module as detected from pyflakes.
filtered_imports = [x.replace(base_module + '.', '')
for x in imports if x not in unused_module]
imports = re.split(pattern=r'\s*,\s*', string=imports.strip())
filtered_imports = _filter_imports(imports, base_module, unused_module)

# All of the import in this statement is unused
if not filtered_imports:
Expand Down Expand Up @@ -387,26 +553,32 @@ def filter_code(source, additional_imports=None,

sio = io.StringIO(source)
previous_line = ''
result = None
for line_number, line in enumerate(sio.readlines(), start=1):
if '#' in line:
yield line
if isinstance(result, PendingFix):
result = result(line)
elif '#' in line:
result = line
elif line_number in marked_import_line_numbers:
yield filter_unused_import(
result = filter_unused_import(
line,
unused_module=marked_unused_module[line_number],
remove_all_unused_imports=remove_all_unused_imports,
imports=imports,
previous_line=previous_line)
elif line_number in marked_variable_line_numbers:
yield filter_unused_variable(line)
result = filter_unused_variable(line)
elif line_number in marked_key_line_numbers:
yield filter_duplicate_key(line, line_messages[line_number],
line_number, marked_key_line_numbers,
source)
result = filter_duplicate_key(line, line_messages[line_number],
line_number, marked_key_line_numbers,
source)
elif line_number in marked_star_import_line_numbers:
yield filter_star_import(line, undefined_names)
result = filter_star_import(line, undefined_names)
else:
yield line
result = line

if not isinstance(result, PendingFix):
yield result

previous_line = line

Expand All @@ -428,9 +600,16 @@ def filter_star_import(line, marked_star_import_undefined_name):
def filter_unused_import(line, unused_module, remove_all_unused_imports,
imports, previous_line=''):
"""Return line if used, otherwise return None."""
if multiline_import(line, previous_line):
# Ignore doctests.
if line.lstrip().startswith('>'):
return line

if multiline_import(line, previous_line):
filt = FilterMultilineImport(line, unused_module,
remove_all_unused_imports,
imports, previous_line)
return filt()

is_from_import = line.lstrip().startswith('from')

if ',' in line and not is_from_import:
Expand Down
Loading