diff --git a/mypy_boto3_builder/parsers/shape_parser.py b/mypy_boto3_builder/parsers/shape_parser.py index 55f3f06f..20292a70 100644 --- a/mypy_boto3_builder/parsers/shape_parser.py +++ b/mypy_boto3_builder/parsers/shape_parser.py @@ -77,6 +77,7 @@ get_type_def_name, xform_name, ) +from mypy_boto3_builder.utils.type_checks import is_union class ShapeParser: @@ -1309,8 +1310,11 @@ def convert_input_arguments_to_unions(self, methods: Sequence[Method]) -> None: ) for type_annotation in parent_type_annotations: parents = type_annotation.find_type_annotation_parents(input_typed_dict) + if not parents: + continue + for parent in sorted(parents): - if parent is union_type_annotation: + if is_union(parent) and parent.name == union_name: continue self.logger.debug( f"Adding output shape to {parent.render()} type:" diff --git a/mypy_boto3_builder/type_annotations/type_parent.py b/mypy_boto3_builder/type_annotations/type_parent.py index e5d0d929..a00bcff4 100644 --- a/mypy_boto3_builder/type_annotations/type_parent.py +++ b/mypy_boto3_builder/type_annotations/type_parent.py @@ -41,26 +41,50 @@ def get_children_types(self) -> set[FakeAnnotation]: def find_type_annotation_parents( self, type_annotation: FakeAnnotation, - skip: Iterable[FakeAnnotation] = (), ) -> "set[TypeParent]": """ Check recursively if child is present in type def. """ - result: set[TypeParent] = set() - for child_type in self.iterate_children_type_annotations(): - if child_type == type_annotation: - result.add(self) - if not isinstance(child_type, TypeParent): - continue + return self.find_type_annotation_parent_map({type_annotation}).get(type_annotation) or set() - if child_type in skip: - continue + def find_type_annotation_parent_map( + self, type_annotations: Iterable[FakeAnnotation] + ) -> "dict[FakeAnnotation, set[TypeParent]]": + """ + Check recursively if children are present in type def. + + Can be used for non-overlapping type annotations. + """ + result: dict[FakeAnnotation, set[TypeParent]] = {} + for parent in self.find_parents(): + for child_type in parent.iterate_children_type_annotations(): + if child_type not in type_annotations: + continue + + if child_type not in result: + result[child_type] = set() + + result[child_type].add(parent) + + return result + + def find_parents(self) -> "set[TypeParent]": + """ + Find all parents recursively including self. + """ + result: set[TypeParent] = {self} + stack: list[TypeParent] = [self] + + while stack: + current = stack.pop() + for child_type in current.iterate_children_type_annotations(): + if not isinstance(child_type, TypeParent): + continue + if child_type in result: + continue - parents = child_type.find_type_annotation_parents( - type_annotation, - skip={*skip, child_type}, - ) - result.update(parents) + result.add(child_type) + stack.append(child_type) return result diff --git a/tests/type_annotations/test_type_subscript.py b/tests/type_annotations/test_type_subscript.py index e9a08a63..5145431a 100644 --- a/tests/type_annotations/test_type_subscript.py +++ b/tests/type_annotations/test_type_subscript.py @@ -55,6 +55,13 @@ def test_find_type_annotation_parents(self) -> None: assert outer.find_type_annotation_parents(Type.str) == {outer} assert outer.find_type_annotation_parents(Type.List) == set() + def test_find_type_annotation_parent_map(self) -> None: + inner = TypeSubscript(Type.List, [Type.int]) + outer = TypeSubscript(Type.Dict, [Type.str, inner]) + assert outer.find_type_annotation_parent_map([Type.int]) == {Type.int: {inner}} + assert outer.find_type_annotation_parent_map([Type.str]) == {Type.str: {outer}} + assert outer.find_type_annotation_parent_map([Type.List]) == {} + def test_replace_child(self) -> None: inner = TypeSubscript(Type.List, [Type.int]) outer = TypeSubscript(Type.Dict, [Type.str, inner])