Skip to content

Commit

Permalink
Merge pull request #482 from MarcoGorelli/await
Browse files Browse the repository at this point in the history
Upgrade to f-string get broken for Py3.6 in case await keyword is in
  • Loading branch information
asottile authored Jul 10, 2021
2 parents 513fe79 + 039c9df commit df2cf61
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
22 changes: 17 additions & 5 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,13 +520,22 @@ def _format_params(call: ast.Call) -> Set[str]:
return params


def _contains_await(node: ast.AST) -> bool:
for node_ in ast.walk(node):
if isinstance(node_, ast.Await):
return True
else:
return False


class FindPy36Plus(ast.NodeVisitor):
def __init__(self) -> None:
def __init__(self, *, min_version: Version) -> None:
self.fstrings: Dict[Offset, ast.Call] = {}
self.named_tuples: Dict[Offset, ast.Call] = {}
self.dict_typed_dicts: Dict[Offset, ast.Call] = {}
self.kw_typed_dicts: Dict[Offset, ast.Call] = {}
self._from_imports: Dict[str, Set[str]] = collections.defaultdict(set)
self.min_version = min_version

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.level == 0 and node.module in {'typing', 'typing_extensions'}:
Expand Down Expand Up @@ -591,7 +600,8 @@ def visit_Call(self, node: ast.Call) -> None:
if not candidate:
i += 1
else:
self.fstrings[ast_to_offset(node)] = node
if self.min_version >= (3, 7) or not _contains_await(node):
self.fstrings[ast_to_offset(node)] = node

self.generic_visit(node)

Expand Down Expand Up @@ -758,13 +768,13 @@ def _typed_class_replacement(
return end, attrs


def _fix_py36_plus(contents_text: str) -> str:
def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str:
try:
ast_obj = ast_parse(contents_text)
except SyntaxError:
return contents_text

visitor = FindPy36Plus()
visitor = FindPy36Plus(min_version=min_version)
visitor.visit(ast_obj)

if not any((
Expand Down Expand Up @@ -871,7 +881,9 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int:
)
contents_text = _fix_tokens(contents_text, min_version=args.min_version)
if args.min_version >= (3, 6):
contents_text = _fix_py36_plus(contents_text)
contents_text = _fix_py36_plus(
contents_text, min_version=args.min_version,
)

if filename == '-':
print(contents_text, end='')
Expand Down
13 changes: 11 additions & 2 deletions tests/features/fstrings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@
r'''"{}".format(a['\\'])''',
'"{}".format(a["b"])',
"'{}'.format(a['b'])",
# await only becomes keyword in Python 3.7+
"async def c(): return '{}'.format(await 3)",
"async def c(): return '{}'.format(1 + await 3)",
),
)
def test_fix_fstrings_noop(s):
assert _fix_py36_plus(s) == s
assert _fix_py36_plus(s, min_version=(3, 6)) == s


@pytest.mark.parametrize(
Expand All @@ -60,4 +63,10 @@ def test_fix_fstrings_noop(s):
),
)
def test_fix_fstrings(s, expected):
assert _fix_py36_plus(s) == expected
assert _fix_py36_plus(s, min_version=(3, 6)) == expected


def test_fix_fstrings_await_py37():
s = "async def c(): return '{}'.format(await 1+foo())"
expected = "async def c(): return f'{await 1+foo()}'"
assert _fix_py36_plus(s, min_version=(3, 7)) == expected
4 changes: 2 additions & 2 deletions tests/features/typing_named_tuple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
),
)
def test_typing_named_tuple_noop(s):
assert _fix_py36_plus(s) == s
assert _fix_py36_plus(s, min_version=(3, 6)) == s


@pytest.mark.parametrize(
Expand Down Expand Up @@ -171,4 +171,4 @@ def test_typing_named_tuple_noop(s):
),
)
def test_fix_typing_named_tuple(s, expected):
assert _fix_py36_plus(s) == expected
assert _fix_py36_plus(s, min_version=(3, 6)) == expected
4 changes: 2 additions & 2 deletions tests/features/typing_typed_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
),
)
def test_typing_typed_dict_noop(s):
assert _fix_py36_plus(s) == s
assert _fix_py36_plus(s, min_version=(3, 6)) == s


@pytest.mark.parametrize(
Expand Down Expand Up @@ -137,4 +137,4 @@ def test_typing_typed_dict_noop(s):
),
)
def test_typing_typed_dict(s, expected):
assert _fix_py36_plus(s) == expected
assert _fix_py36_plus(s, min_version=(3, 6)) == expected

0 comments on commit df2cf61

Please sign in to comment.