Skip to content

Commit

Permalink
rewrite typeddict even with total= option
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored and asottile committed Apr 9, 2021
1 parent 1923d1c commit 84490d8
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 9 deletions.
48 changes: 39 additions & 9 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,11 @@ def visit_Assign(self, node: ast.Assign) -> None:
'TypedDict',
) and
len(node.value.args) == 1 and
len(node.value.keywords) > 0
len(node.value.keywords) > 0 and
not any(
keyword.arg == 'total'
for keyword in node.value.keywords
)
):
self.kw_typed_dicts[ast_to_offset(node)] = node.value
elif (
Expand All @@ -654,7 +658,17 @@ def visit_Assign(self, node: ast.Assign) -> None:
'TypedDict',
) and
len(node.value.args) == 2 and
not node.value.keywords and
(
not node.value.keywords or
(
len(node.value.keywords) == 1 and
node.value.keywords[0].arg == 'total' and
isinstance(
node.value.keywords[0].value,
(ast.Constant, ast.NameConstant),
)
)
) and
isinstance(node.value.args[1], ast.Dict) and
node.value.args[1].keys and
all(
Expand Down Expand Up @@ -718,12 +732,12 @@ def _to_fstring(src: str, call: ast.Call) -> str:
return unparse_parsed_string(parts)


def _replace_typed_class(
def _typed_class_replacement(
tokens: List[Token],
i: int,
call: ast.Call,
types: Dict[str, ast.expr],
) -> None:
) -> Tuple[int, str]:
if i > 0 and tokens[i - 1].name in {'INDENT', UNIMPORTANT_WS}:
indent = f'{tokens[i - 1].src}{" " * 4}'
else:
Expand All @@ -736,8 +750,7 @@ def _replace_typed_class(
end += 1

attrs = '\n'.join(f'{indent}{k}: {_unparse(v)}' for k, v in types.items())
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]
return end, attrs


def _fix_py36_plus(contents_text: str) -> str:
Expand Down Expand Up @@ -788,14 +801,18 @@ def _fix_py36_plus(contents_text: str) -> str:
tup.elts[0].s: tup.elts[1] # type: ignore # (checked above)
for tup in call.args[1].elts # type: ignore # (checked above)
}
_replace_typed_class(tokens, i, call, types)
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]
elif token.offset in visitor.kw_typed_dicts and token.name == 'NAME':
call = visitor.kw_typed_dicts[token.offset]
types = {
arg.arg: arg.value # type: ignore # (checked above)
for arg in call.keywords
}
_replace_typed_class(tokens, i, call, types)
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]
elif token.offset in visitor.dict_typed_dicts and token.name == 'NAME':
call = visitor.dict_typed_dicts[token.offset]
types = {
Expand All @@ -805,7 +822,20 @@ def _fix_py36_plus(contents_text: str) -> str:
call.args[1].values, # type: ignore # (checked above)
)
}
_replace_typed_class(tokens, i, call, types)
if call.keywords:
total = call.keywords[0].value.value # type: ignore # (checked above) # noqa: E501
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = (
f'class {tokens[i].src}('
f'{_unparse(call.func)}, total={total}'
f'):\n'
f'{attrs}'
)
tokens[i:end] = [Token('CODE', src)]
else:
end, attrs = _typed_class_replacement(tokens, i, call, types)
src = f'class {tokens[i].src}({_unparse(call.func)}):\n{attrs}'
tokens[i:end] = [Token('CODE', src)]

return tokens_to_src(tokens)

Expand Down
24 changes: 24 additions & 0 deletions tests/features/typing_typed_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
'D = typing.TypedDict("D", **types)',
id='starstarkwargs',
),
pytest.param(
'D = typing.TypedDict("D", x=int, total=False)',
id='kw_typed_dict with total',
),
),
)
def test_typing_typed_dict_noop(s):
Expand Down Expand Up @@ -78,6 +82,16 @@ def test_typing_typed_dict_noop(s):
id='TypedDict from dict literal',
),
pytest.param(
'import typing\n'
'D = typing.TypedDict("D", {"a": int}, total=False)\n',
'import typing\n'
'class D(typing.TypedDict, total=False):\n'
' a: int\n',
id='TypedDict from dict literal with total',
),
pytest.param(
'from typing_extensions import TypedDict\n'
'D = TypedDict("D", a=int)\n',
Expand All @@ -98,6 +112,16 @@ def test_typing_typed_dict_noop(s):
id='keyword TypedDict from typing_extensions',
),
pytest.param(
'import typing_extensions\n'
'D = typing_extensions.TypedDict("D", {"a": int}, total=True)\n',
'import typing_extensions\n'
'class D(typing_extensions.TypedDict, total=True):\n'
' a: int\n',
id='keyword TypedDict from typing_extensions, with total',
),
pytest.param(
'from typing import List\n'
'from typing_extensions import TypedDict\n'
Expand Down

0 comments on commit 84490d8

Please sign in to comment.