diff --git a/pyupgrade/_plugins/open_mode.py b/pyupgrade/_plugins/open_mode.py index 271007a1..a8cad514 100644 --- a/pyupgrade/_plugins/open_mode.py +++ b/pyupgrade/_plugins/open_mode.py @@ -1,6 +1,8 @@ import ast +import functools from typing import Iterable from typing import List +from typing import NamedTuple from typing import Tuple from tokenize_rt import Offset @@ -21,19 +23,28 @@ U_MODE_REPLACE = U_MODE_REPLACE_R | U_MODE_REMOVE_U -def _fix_open_mode(i: int, tokens: List[Token]) -> None: +class FunctionArg(NamedTuple): + arg_idx: int + value: ast.expr + + +def _fix_open_mode(i: int, tokens: List[Token], *, arg_idx: int) -> None: j = find_open_paren(tokens, i) func_args, end = parse_call_args(tokens, j) - mode = tokens_to_src(tokens[slice(*func_args[1])]) - mode_stripped = mode.strip().strip('"\'') + mode = tokens_to_src(tokens[slice(*func_args[arg_idx])]) + mode_stripped = mode.split('=')[-1] + mode_stripped = mode_stripped.strip().strip('"\'') if mode_stripped in U_MODE_REMOVE: - del tokens[func_args[0][1]:func_args[1][1]] + if arg_idx == 0: + del tokens[func_args[arg_idx][0]: func_args[arg_idx + 1][0]] + else: + del tokens[func_args[arg_idx - 1][1]:func_args[arg_idx][1]] elif mode_stripped in U_MODE_REPLACE_R: new_mode = mode.replace('U', 'r') - tokens[slice(*func_args[1])] = [Token('SRC', new_mode)] + tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)] elif mode_stripped in U_MODE_REMOVE_U: new_mode = mode.replace('U', '') - tokens[slice(*func_args[1])] = [Token('SRC', new_mode)] + tokens[slice(*func_args[arg_idx])] = [Token('SRC', new_mode)] else: raise AssertionError(f'unreachable: {mode!r}') @@ -48,11 +59,37 @@ def visit_Call( state.settings.min_version >= (3,) and isinstance(node.func, ast.Name) and node.func.id == 'open' and - not has_starargs(node) and - len(node.args) >= 2 and - isinstance(node.args[1], ast.Str) and ( + not has_starargs(node) + ): + if len(node.args) >= 2 and isinstance(node.args[1], ast.Str): + if ( node.args[1].s in U_MODE_REPLACE or (len(node.args) == 2 and node.args[1].s in U_MODE_REMOVE) + ): + func = functools.partial( + _fix_open_mode, + arg_idx=1, + ) + yield ast_to_offset(node), func + elif node.keywords and (len(node.keywords) + len(node.args) > 1): + mode = next( + ( + FunctionArg(n, keyword.value) + for n, keyword in enumerate(node.keywords) + if keyword.arg == 'mode' + ), + None, ) - ): - yield ast_to_offset(node), _fix_open_mode + if ( + mode is not None and + isinstance(mode.value, ast.Str) and + ( + mode.value.s in U_MODE_REMOVE or + mode.value.s in U_MODE_REPLACE + ) + ): + func = functools.partial( + _fix_open_mode, + arg_idx=len(node.args) + mode.arg_idx, + ) + yield ast_to_offset(node), func diff --git a/tests/features/open_mode_test.py b/tests/features/open_mode_test.py index 92826ad3..bffc854b 100644 --- a/tests/features/open_mode_test.py +++ b/tests/features/open_mode_test.py @@ -9,9 +9,13 @@ ( # already a reduced mode 'open("foo", "w")', + 'open("foo", mode="w")', 'open("foo", "rb")', # nonsense mode 'open("foo", "Uw")', + 'open("foo", qux="r")', + 'open("foo", 3)', + 'open(mode="r")', # TODO: could maybe be rewritten to remove t? 'open("foo", "wt")', # don't remove this, they meant to use `encoding=` @@ -26,12 +30,38 @@ def test_fix_open_mode_noop(s): ('s', 'expected'), ( ('open("foo", "U")', 'open("foo")'), + ('open("foo", mode="U")', 'open("foo")'), ('open("foo", "Ur")', 'open("foo")'), + ('open("foo", mode="Ur")', 'open("foo")'), ('open("foo", "Ub")', 'open("foo", "rb")'), + ('open("foo", mode="Ub")', 'open("foo", mode="rb")'), ('open("foo", "rUb")', 'open("foo", "rb")'), + ('open("foo", mode="rUb")', 'open("foo", mode="rb")'), ('open("foo", "r")', 'open("foo")'), + ('open("foo", mode="r")', 'open("foo")'), ('open("foo", "rt")', 'open("foo")'), + ('open("foo", mode="rt")', 'open("foo")'), ('open("f", "r", encoding="UTF-8")', 'open("f", encoding="UTF-8")'), + ( + 'open("f", mode="r", encoding="UTF-8")', + 'open("f", encoding="UTF-8")', + ), + ( + 'open(file="f", mode="r", encoding="UTF-8")', + 'open(file="f", encoding="UTF-8")', + ), + ( + 'open("f", encoding="UTF-8", mode="r")', + 'open("f", encoding="UTF-8")', + ), + ( + 'open(file="f", encoding="UTF-8", mode="r")', + 'open(file="f", encoding="UTF-8")', + ), + ( + 'open(mode="r", encoding="UTF-8", file="t.py")', + 'open( encoding="UTF-8", file="t.py")', + ), ), ) def test_fix_open_mode(s, expected):