Skip to content

Commit

Permalink
Infer typing.cast return depending on whether argument is inferable
Browse files Browse the repository at this point in the history
In the case of typing.cast(A, x) where x is inferable, we prefer to
yield the type of x since we don't want to lose information. If x is
uninferable, then treating it as an instance of A is reasonable.
  • Loading branch information
timmartin committed Jul 8, 2021
1 parent 1d5c62e commit 4045b6e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
9 changes: 7 additions & 2 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Copyright (c) 2021 hippo91 <guillaume.peillex@gmail.com>

"""Astroid hooks for typing.py support."""
import itertools
import typing
from functools import partial

Expand Down Expand Up @@ -365,8 +366,12 @@ def infer_typing_cast(
if func.qname() != "typing.cast" or len(node.args) != 2:
raise UseInferenceDefault

type_node = next(node.args[0].infer(context=ctx))
return iter([Instance(type_node)])
peek_value_node, value_node = itertools.tee(node.args[1].infer(context=ctx))
if next(peek_value_node) is Uninferable:
type_node = next(node.args[0].infer(context=ctx))
return iter([Instance(type_node)])
else:
return value_node


AstroidManager().register_transform(
Expand Down
20 changes: 19 additions & 1 deletion tests/unittest_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,7 +1801,8 @@ def test_typing_object_builtin_subscriptable(self):
self.assertIsInstance(inferred, nodes.ClassDef)
self.assertIsInstance(inferred.getattr("__iter__")[0], nodes.FunctionDef)

def test_typing_cast(self):
def test_typing_cast_uninferable(self):
"""cast will yield instance of casted-to type if not otherwise inferable"""
node = builder.extract_node(
"""
from typing import cast
Expand All @@ -1817,6 +1818,23 @@ class A:
assert isinstance(inferred, bases.Instance)
assert inferred.name == "A"

def test_typing_cast_inferable(self):
"""cast leaves type unchanged if it has been inferred"""
node = builder.extract_node(
"""
from typing import cast
class A:
pass
b = list()
a = cast(A, b)
a
"""
)
inferred = next(node.infer())
assert isinstance(inferred, bases.Instance)
assert inferred.name == "list"

def test_typing_cast_attribute(self):
node = builder.extract_node(
"""
Expand Down

0 comments on commit 4045b6e

Please sign in to comment.