Skip to content

Commit

Permalink
Detect % formatting and directives for bytes and bytearray
Browse files Browse the repository at this point in the history
  • Loading branch information
netromdk committed Jan 12, 2020
1 parent 0ac202f commit 0444434
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 2 deletions.
1 change: 1 addition & 0 deletions runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ def runsuite(suite):
runsuite("exclusions")
runsuite("comment_exclusions")
runsuite("backports")
runsuite("bytes_directive")
15 changes: 15 additions & 0 deletions tests/bytes_directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .testutils import VerminTest, detect, current_version

class VerminBytesDirectiveTests(VerminTest):
def test_b_directive(self):
if current_version() >= 3.5:
self.assertOnlyIn((3, 5), detect("b'%b' % 10"))

def test_a_directive(self):
if current_version() >= 3.5:
self.assertOnlyIn((3, 5), detect("b'%a' % 'x'"))

def test_r_directive(self):
v = current_version()
if v < 3 or v >= 3.5:
self.assertOnlyIn(((2, 7), (3, 5)), detect("b'%r' % 'x'"))
33 changes: 33 additions & 0 deletions tests/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,36 @@ def test_generalized_unpacking(self):
visitor = visit("function(**{'x': 42}, arg=84)")
self.assertTrue(visitor.generalized_unpacking())
self.assertOnlyIn((3, 5), visitor.minimum_versions())

def test_bytes_format(self):
v = current_version()
if v < 3 or v >= 3.5:
visitor = visit("b'%x' % 10")
self.assertTrue(visitor.bytes_format())
self.assertOnlyIn(((2, 6), (3, 5)), visitor.minimum_versions())

def test_bytearray_format(self):
if current_version() >= 3.5:
visitor = visit("bytearray(b'%x') % 10")
self.assertTrue(visitor.bytearray_format())
self.assertOnlyIn((3, 5), visitor.minimum_versions())

def test_bytes_directives(self):
visitor = visit("b'%b %x'")
self.assertOnlyIn(("b", "x"), visitor.bytes_directives())
visitor = visit("b'%4b'")
self.assertOnlyIn(("b",), visitor.bytes_directives())
visitor = visit("b'%4b'")
self.assertOnlyIn(("b",), visitor.bytes_directives())
visitor = visit("b'%#4b'")
self.assertOnlyIn(("b",), visitor.bytes_directives())
visitor = visit("b'%04b'")
self.assertOnlyIn(("b",), visitor.bytes_directives())
visitor = visit("b'%.4f'")
self.assertOnlyIn(("f",), visitor.bytes_directives())
visitor = visit("b'%-4f'")
self.assertOnlyIn(("f",), visitor.bytes_directives())
visitor = visit("b'% f'")
self.assertOnlyIn(("f",), visitor.bytes_directives())
visitor = visit("b'%+f'")
self.assertOnlyIn(("f",), visitor.bytes_directives())
6 changes: 6 additions & 0 deletions vermin/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2935,6 +2935,12 @@ def MOD_MEM_REQS():
"f": ((2, 6), (3, 0)),
"u": (None, (3, 6)),
}

# bytes/bytearray (and str for 2 compatibility) requirements: directive -> requiresments
BYTES_REQS = {
"a": (None, (3, 5)),
"b": (None, (3, 5)),
"r": ((2, 7), (3, 5)),
}

# array.array typecode requirements: typecode -> requirements
Expand Down
68 changes: 66 additions & 2 deletions vermin/source_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import re
import sys

from .rules import MOD_REQS, MOD_MEM_REQS, KWARGS_REQS, STRFTIME_REQS, ARRAY_TYPECODE_REQS, \
CODECS_ERROR_HANDLERS, CODECS_ERRORS_INDICES, CODECS_ENCODINGS, CODECS_ENCODINGS_INDICES
from .rules import MOD_REQS, MOD_MEM_REQS, KWARGS_REQS, STRFTIME_REQS, BYTES_REQS,\
ARRAY_TYPECODE_REQS, CODECS_ERROR_HANDLERS, CODECS_ERRORS_INDICES, CODECS_ENCODINGS,\
CODECS_ENCODINGS_INDICES
from .config import Config
from .utility import dotted_name, reverse_range, combine_versions, version_strings

STRFTIME_DIRECTIVE_REGEX = re.compile(r"%(?:[-\.\d#\s\+])*(\w)")
BYTES_DIRECTIVE_REGEX = STRFTIME_DIRECTIVE_REGEX

class SourceVisitor(ast.NodeVisitor):
def __init__(self, config=None):
Expand Down Expand Up @@ -53,10 +55,13 @@ def __init__(self, config=None):
self.__depth = 0
self.__line = 1
self.__strftime_directives = []
self.__bytes_directives = []
self.__codecs_error_handlers = []
self.__codecs_encodings = []
self.__with_statement = False
self.__generalized_unpacking = False
self.__bytes_format = False
self.__bytearray_format = False

# Imported members of modules, like "exc_clear" of "sys".
self.__import_mem_mod = {}
Expand Down Expand Up @@ -151,6 +156,9 @@ def pos_only_args(self):
def strftime_directives(self):
return self.__strftime_directives

def bytes_directives(self):
return self.__bytes_directives

def user_defined(self):
return self.__user_defs

Expand Down Expand Up @@ -187,6 +195,12 @@ def with_statement(self):
def generalized_unpacking(self):
return self.__generalized_unpacking

def bytes_format(self):
return self.__bytes_format

def bytearray_format(self):
return self.__bytearray_format

def minimum_versions(self):
mins = [(0, 0), (0, 0)]

Expand Down Expand Up @@ -280,13 +294,28 @@ def minimum_versions(self):
if self.generalized_unpacking():
mins = combine_versions(mins, (None, (3, 5)))

if self.bytes_format():
# Since byte strings are a `str` synonym as of 2.6+, and thus also supports `%` formatting,
# (2, 6) is returned instead of None.
mins = combine_versions(mins, ((2, 6), (3, 5)))

if self.bytearray_format():
mins = combine_versions(mins, (None, (3, 5)))

for directive in self.strftime_directives():
if directive in STRFTIME_REQS:
vers = STRFTIME_REQS[directive]
self.__vvprint("strftime directive '{}' requires {}".
format(directive, version_strings(vers)), directive)
mins = combine_versions(mins, vers)

for directive in self.bytes_directives():
if directive in BYTES_REQS:
vers = BYTES_REQS[directive]
self.__vvprint("bytes directive '{}' requires {}".
format(directive, version_strings(vers)), directive)
mins = combine_versions(mins, vers)

for typecode in self.array_typecodes():
if typecode in ARRAY_TYPECODE_REQS:
vers = ARRAY_TYPECODE_REQS[typecode]
Expand Down Expand Up @@ -430,6 +459,10 @@ def __add_strftime_directive(self, group, line=None, col=None):
self.__strftime_directives.append(group)
self.__add_line_col(group, line, col)

def __add_bytes_directive(self, group, line=None, col=None):
self.__bytes_directives.append(group)
self.__add_line_col(group, line, col)

def __add_codecs_error_handler(self, func, node):
if func in CODECS_ERRORS_INDICES:
idx = CODECS_ERRORS_INDICES[func]
Expand Down Expand Up @@ -852,11 +885,42 @@ def visit_Bytes(self, node):
self.__bytesv3 = True
self.__vvprint("byte strings (b'..') require 3+ (or 2.6+ as `str` synonym)")

if hasattr(node, "s"):
for directive in BYTES_DIRECTIVE_REGEX.findall(str(node.s)):
self.__add_bytes_directive(directive, node.lineno)

def visit_Str(self, node):
# As bytes to str fallback in python 2, add bytes formatting directives.
if sys.version_info.major == 2 and hasattr(node, "s"):
for directive in BYTES_DIRECTIVE_REGEX.findall(node.s):
self.__add_bytes_directive(directive, node.lineno)

def visit_BinOp(self, node):
# Examples:
# BinOp(left=Bytes(s=b'%4x'), op=Mod(), right=Num(n=10))
# BinOp(left=Call(func=Name(id='bytearray', ctx=Load()), args=[Bytes(s=b'%x')], keywords=[]),
# op=Mod(), right=Num(n=10))
if ((hasattr(ast, "Bytes") and isinstance(node.left, ast.Bytes)) or
isinstance(node.left, ast.Str)) and isinstance(node.op, ast.Mod):
self.__bytes_format = True
self.__vvprint("bytes `%` formatting requires 3.5+ (or 2.6+ as `str` synonym)")

if (isinstance(node.left, ast.Call) and isinstance(node.left.func, ast.Name) and
node.left.func.id == "bytearray") and isinstance(node.op, ast.Mod):
self.__bytearray_format = True
self.__vvprint("bytearray `%` formatting requires 3.5+")

self.generic_visit(node)

def visit_Constant(self, node):
# From 3.8, Bytes(s=b'%x') is represented as Constant(value=b'%x', kind=None) instead.
if hasattr(node, "value") and type(node.value) == bytes:
self.__bytesv3 = True
self.__vvprint("byte strings (b'..') require 3+")

for directive in BYTES_DIRECTIVE_REGEX.findall(str(node.value)):
self.__add_bytes_directive(directive, node.lineno)

def visit_JoinedStr(self, node):
self.__fstrings = True
self.__vvprint("f-strings require 3.6+")
Expand Down

0 comments on commit 0444434

Please sign in to comment.