From 84490d87fc42297d98cc63d704e8b8735f9fc7a9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sat, 3 Apr 2021 14:59:17 +0100 Subject: [PATCH] rewrite typeddict even with total= option --- pyupgrade/_main.py | 48 +++++++++++++++++++----- tests/features/typing_typed_dict_test.py | 24 ++++++++++++ 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/pyupgrade/_main.py b/pyupgrade/_main.py index 86c3a97f..e3035001 100644 --- a/pyupgrade/_main.py +++ b/pyupgrade/_main.py @@ -644,7 +644,11 @@ def visit_Assign(self, node: ast.Assign) -> None: 'TypedDict', ) and len(node.value.args) == 1 and - len(node.value.keywords) > 0 + len(node.value.keywords) > 0 and + not any( + keyword.arg == 'total' + for keyword in node.value.keywords + ) ): self.kw_typed_dicts[ast_to_offset(node)] = node.value elif ( @@ -654,7 +658,17 @@ def visit_Assign(self, node: ast.Assign) -> None: 'TypedDict', ) and len(node.value.args) == 2 and - not node.value.keywords and + ( + not node.value.keywords or + ( + len(node.value.keywords) == 1 and + node.value.keywords[0].arg == 'total' and + isinstance( + node.value.keywords[0].value, + (ast.Constant, ast.NameConstant), + ) + ) + ) and isinstance(node.value.args[1], ast.Dict) and node.value.args[1].keys and all( @@ -718,12 +732,12 @@ def _to_fstring(src: str, call: ast.Call) -> str: return unparse_parsed_string(parts) -def _replace_typed_class( +def _typed_class_replacement( tokens: List[Token], i: int, call: ast.Call, types: Dict[str, ast.expr], -) -> None: +) -> Tuple[int, str]: if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}: indent = f'{tokens[i - 1].src}{" " * 4}' else: @@ -736,8 +750,7 @@ def _replace_typed_class( end += 1 attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items()) - src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' - tokens[i:end] = [Token('CODE', src)] + return end, attrs def _fix_py36_plus(contents_text: str) -> str: @@ -788,14 +801,18 @@ def _fix_py36_plus(contents_text: str) -> str: tup.elts[0].s: tup.elts[1] # type: ignore # (checked above) for tup in call.args[1].elts # type: ignore # (checked above) } - _replace_typed_class(tokens, i, call, types) + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME': call = visitor.kw_typed_dicts[token.offset] types = { arg.arg: arg.value # type: ignore # (checked above) for arg in call.keywords } - _replace_typed_class(tokens, i, call, types) + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME': call = visitor.dict_typed_dicts[token.offset] types = { @@ -805,7 +822,20 @@ def _fix_py36_plus(contents_text: str) -> str: call.args[1].values, # type: ignore # (checked above) ) } - _replace_typed_class(tokens, i, call, types) + if call.keywords: + total = call.keywords[0].value.value # type: ignore # (checked above) # noqa: E501 + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = ( + f'class {tokens[i].src}(' + f'{_unparse(call.func)}, total={total}' + f'):\n' + f'{attrs}' + ) + tokens[i:end] = [Token('CODE', src)] + else: + end, attrs = _typed_class_replacement(tokens, i, call, types) + src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}' + tokens[i:end] = [Token('CODE', src)] return tokens_to_src(tokens) diff --git a/tests/features/typing_typed_dict_test.py b/tests/features/typing_typed_dict_test.py index da286836..f3a67335 100644 --- a/tests/features/typing_typed_dict_test.py +++ b/tests/features/typing_typed_dict_test.py @@ -39,6 +39,10 @@ 'D = typing.TypedDict("D", **types)', id='starstarkwargs', ), + pytest.param( + 'D = typing.TypedDict("D", x=int, total=False)', + id='kw_typed_dict with total', + ), ), ) def test_typing_typed_dict_noop(s): @@ -78,6 +82,16 @@ def test_typing_typed_dict_noop(s): id='TypedDict from dict literal', ), + pytest.param( + 'import typing\n' + 'D = typing.TypedDict("D", {"a": int}, total=False)\n', + + 'import typing\n' + 'class D(typing.TypedDict, total=False):\n' + ' a: int\n', + + id='TypedDict from dict literal with total', + ), pytest.param( 'from typing_extensions import TypedDict\n' 'D = TypedDict("D", a=int)\n', @@ -98,6 +112,16 @@ def test_typing_typed_dict_noop(s): id='keyword TypedDict from typing_extensions', ), + pytest.param( + 'import typing_extensions\n' + 'D = typing_extensions.TypedDict("D", {"a": int}, total=True)\n', + + 'import typing_extensions\n' + 'class D(typing_extensions.TypedDict, total=True):\n' + ' a: int\n', + + id='keyword TypedDict from typing_extensions, with total', + ), pytest.param( 'from typing import List\n' 'from typing_extensions import TypedDict\n'