Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(injector): resolve issue with Injector requires all annotations #188

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions flake8_type_checking/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,17 +525,30 @@ def handle_injector_declaration(self, node: Union[AsyncFunctionDef, FunctionDef]

To achieve this, we just visit the annotations to register them as "uses".
"""
for path in [node.args.args, node.args.kwonlyargs]:
for argument in path:
if hasattr(argument, 'annotation') and argument.annotation:
annotation = argument.annotation
if not hasattr(annotation, 'value'):
continue
value = annotation.value
if hasattr(value, 'id') and value.id == 'Inject':
self.visit(argument.annotation)
if hasattr(value, 'attr') and value.attr == 'Inject':
self.visit(argument.annotation)
if not self._has_injected_annotation(node):
return None

for annotation in self._list_annotations(node):
self.visit(annotation)

def _has_injected_annotation(self, node: Union[AsyncFunctionDef, FunctionDef]) -> bool:
for annotation in self._list_annotations(node):
if not hasattr(annotation, 'value'):
continue
value = annotation.value
if hasattr(value, 'id') and value.id == 'Inject':
return True
if hasattr(value, 'attr') and value.attr == 'Inject':
return True
return False

def _list_annotations(self, node: Union[AsyncFunctionDef, FunctionDef]) -> Iterator[ast.AST]:
Daverball marked this conversation as resolved.
Show resolved Hide resolved
for argument in chain(node.args.args, node.args.kwonlyargs, node.args.posonlyargs):
if annotation := getattr(argument, 'annotation', None):
yield annotation
for arg in (node.args.kwarg, node.args.vararg):
if arg and (annotation := getattr(arg, 'annotation', None)):
yield annotation


class FastAPIMixin:
Expand Down
47 changes: 19 additions & 28 deletions tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,7 @@ def __init__(self, service: Inject[Service]) -> None:
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
@pytest.mark.parametrize(('enabled', 'expected'), [(True, set())])
Copy link
Collaborator

@Daverball Daverball Jun 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parametrize hardly makes sense if you only provide one parametrization, also the docstring and test name no longer make any sense. It might be better to refactor and merge some of these tests into more of a regression style test where you provide a list of source code samples and the expected output, similar to the TCXXX test cases. That way it will also be easier to cover all the different kinds of arguments (i.e. *args in addition to **kwargs, and signatures that use both * and / in them)

def test_injector_option_only_allows_injected_dependencies(enabled, expected):
"""Whenever an injector option is enabled, only injected dependencies should be ignored."""
example = textwrap.dedent('''
Expand All @@ -76,20 +63,7 @@ def __init__(self, service: Inject[Service], other: OtherDependency) -> None:
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
(True, {'4:0 ' + TC002.format(module='other_dependency.OtherDependency')}),
(
False,
{
'2:0 ' + TC002.format(module='injector.Inject'),
'3:0 ' + TC002.format(module='services.Service'),
'4:0 ' + TC002.format(module='other_dependency.OtherDependency'),
},
),
],
)
@pytest.mark.parametrize(('enabled', 'expected'), [(True, set())])
def test_injector_option_only_allows_injector_slices(enabled, expected):
"""
Whenever an injector option is enabled, only injected dependencies should be ignored,
Expand All @@ -108,6 +82,23 @@ def __init__(self, service: Inject[Service], other_deps: list[OtherDependency])
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(('enabled', 'expected'), [(True, set())])
def test_injector_option_require_injections_under_unpack(enabled, expected):
"""Whenever an injector option is enabled, injected dependencies should be ignored, even if unpacked."""
example = textwrap.dedent("""
from typing import Unpack

from injector import Inject
from services import ServiceKwargs

class X:
def __init__(self, service: Inject[Service], **kwargs: Unpack[ServiceKwargs]) -> None:
self.service = service
self.args = args
""")
assert _get_error(example, error_code_filter='TC002', type_checking_injector_enabled=enabled) == expected


@pytest.mark.parametrize(
('enabled', 'expected'),
[
Expand Down
Loading