Skip to content
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ What's New in astroid 2.6.3?
============================
Release date: TBA


* Fix a bad inferenece type for yield values inside of a derived class.

Closes PyCQA/astroid#1090

* Fix a crash when the node is a 'Module' in the brain builtin inference

Closes PyCQA/pylint#4671
Expand Down
9 changes: 7 additions & 2 deletions astroid/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import collections

from astroid import context as contextmod
from astroid import util
from astroid import decorators, util
from astroid.const import BUILTINS, PY310_PLUS
from astroid.exceptions import (
AstroidTypeError,
Expand Down Expand Up @@ -543,9 +543,14 @@ class Generator(BaseInstance):

special_attributes = util.lazy_descriptor(objectmodel.GeneratorModel)

def __init__(self, parent=None):
def __init__(self, parent=None, generator_initial_context=None):
super().__init__()
self.parent = parent
self._call_context = contextmod.copy_context(generator_initial_context)

@decorators.cached
def infer_yield_types(self):
yield from self.parent.infer_yield_result(self._call_context)

def callable(self):
return False
Expand Down
16 changes: 1 addition & 15 deletions astroid/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,22 +489,8 @@ def _infer_context_manager(self, mgr, context):
# It doesn't interest us.
raise InferenceError(node=func)

# Get the first yield point. If it has multiple yields,
# then a RuntimeError will be raised.
yield next(inferred.infer_yield_types())

possible_yield_points = func.nodes_of_class(nodes.Yield)
# Ignore yields in nested functions
yield_point = next(
(node for node in possible_yield_points if node.scope() == func), None
)
if yield_point:
if not yield_point.value:
const = nodes.Const(None)
const.parent = yield_point
const.lineno = yield_point.lineno
yield const
else:
yield from yield_point.value.infer(context=context)
elif isinstance(inferred, bases.Instance):
try:
enter = next(inferred.igetattr("__enter__", context=context))
Expand Down
17 changes: 16 additions & 1 deletion astroid/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,21 @@ def is_generator(self):
"""
return bool(next(self._get_yield_nodes_skip_lambdas(), False))

def infer_yield_result(self, context=None):
"""Infer what the function yields when called

:returns: What the function yields
:rtype: iterable(NodeNG or Uninferable) or None
"""
for yield_ in self.nodes_of_class(node_classes.Yield):
if yield_.value is None:
const = node_classes.Const(None)
const.parent = yield_
const.lineno = yield_.lineno
yield const
elif yield_.scope() == self:
yield from yield_.value.infer(context=context)

def infer_call_result(self, caller=None, context=None):
"""Infer what the function returns when called.

Expand All @@ -1719,7 +1734,7 @@ def infer_call_result(self, caller=None, context=None):
generator_cls = bases.AsyncGenerator
else:
generator_cls = bases.Generator
result = generator_cls(self)
result = generator_cls(self, generator_initial_context=context)
yield result
return
# This is really a gigantic hack to work around metaclass generators
Expand Down
21 changes: 21 additions & 0 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6154,5 +6154,26 @@ def test_issue926_binop_referencing_same_name_is_not_uninferable():
assert inferred[0].value == 3


def test_issue_1090_infer_yield_type_base_class():
code = """
import contextlib

class A:
@contextlib.contextmanager
def get(self):
yield self

class B(A):
def play():
pass

with B().get() as b:
b
b
"""
node = extract_node(code)
assert next(node.infer()).pytype() == ".B"


if __name__ == "__main__":
unittest.main()