diff --git a/rewrite/rewrite/python/format/spaces_visitor.py b/rewrite/rewrite/python/format/spaces_visitor.py index 6f0dd7f0..9b67dc74 100644 --- a/rewrite/rewrite/python/format/spaces_visitor.py +++ b/rewrite/rewrite/python/format/spaces_visitor.py @@ -6,7 +6,7 @@ MethodDeclaration, Empty, ArrayAccess, Space, If, Block, ClassDeclaration, VariableDeclarations, JRightPadded, \ Import from rewrite.python import PythonVisitor, SpacesStyle, Binary, ChainedAssignment, Slice, CollectionLiteral, \ - ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree + ForLoop, DictLiteral, KeyValue, TypeHint, MultiImport, ExpressionTypeTree, ComprehensionExpression from rewrite.visitor import P @@ -408,6 +408,43 @@ def visit_expression_type_tree(self, expression_type_tree: ExpressionTypeTree, p ett = space_before(ett, False) return ett + def visit_comprehension_expression(self, comprehension_expression: ComprehensionExpression, p: P) -> J: + ce = cast(ComprehensionExpression, super().visit_comprehension_expression(comprehension_expression, p)) + + # Handle space before result this will depend on the style setting for the comprehension type. + if ce.kind == ComprehensionExpression.Kind.LIST: + ce = ce.with_result(space_before(ce.result, self._style.within.brackets)) + ce = ce.with_suffix(update_space(ce.suffix, self._style.within.brackets)) + elif ce.kind == ComprehensionExpression.Kind.GENERATOR: + ce = ce.with_result(space_before(ce.result, False)) + ce = ce.with_suffix(update_space(ce.suffix, False)) + elif ce.kind in (ComprehensionExpression.Kind.SET, ComprehensionExpression.Kind.DICT): + ce = ce.with_result(space_before(ce.result, self._style.within.braces)) + ce = ce.with_suffix(update_space(ce.suffix, self._style.within.braces)) + + return ce + + def visit_comprehension_condition(self, condition: ComprehensionExpression.Condition, p: P) -> J: + cond = cast(ComprehensionExpression.Condition, super().visit_comprehension_condition(condition, p)) + # Set single space before and after comprehension 'if' keyword. + cond = space_before(cond, True) + cond = cond.with_expression(space_before(cond.expression, True)) + return cond + + def visit_comprehension_clause(self, clause: ComprehensionExpression.Clause, p: P) -> J: + cc = cast(ComprehensionExpression.Clause, super().visit_comprehension_clause(clause, p)) + + # Ensure single space before 'for' keyword + cc = space_before(cc, True) + + # Single before 'in' keyword e.g. ..i in... <-> ...i in... + cc = cc.padding.with_iterated_list(space_before_left_padded(cc.padding.iterated_list, True)) + # Single space before 'iterator' variable (or after for keyword) e.g. ...for i <-> ...for i + cc = cc.with_iterator_variable(space_before(cc.iterator_variable, True)) + # Ensure single space after 'in' keyword e.g. ...in range(10) <-> ...in range(10) + cc = cc.padding.with_iterated_list(space_before_left_padded_element(cc.padding.iterated_list, True)) + return cc + def _remap_trailing_comma_space(self, tc: j.TrailingComma) -> j.TrailingComma: return tc.with_suffix(update_space(tc.suffix, self._style.other.after_comma)) @@ -438,6 +475,9 @@ def space_before_container(container: j.JContainer, add_space: bool) -> j.JConta return container +def space_before_left_padded_element(container: j.JLeftPadded, add_space: bool) -> j.JLeftPadded: + return container.with_element(space_before(container.element, add_space)) + def space_before_right_padded_element(container: j.JRightPadded, add_space: bool) -> j.JRightPadded: return container.with_element(space_before(container.element, add_space)) diff --git a/rewrite/tests/python/all/format/spaces/comprehension_spaces_test.py b/rewrite/tests/python/all/format/spaces/comprehension_spaces_test.py new file mode 100644 index 00000000..5712e217 --- /dev/null +++ b/rewrite/tests/python/all/format/spaces/comprehension_spaces_test.py @@ -0,0 +1,118 @@ +import pytest + +from rewrite.python import IntelliJ, SpacesVisitor +from rewrite.test import rewrite_run, python, RecipeSpec, from_visitor + + +@pytest.mark.parametrize("within_brackets", [False, True]) +def test_spaces_with_list_comprehension(within_brackets): + style = IntelliJ.spaces() + style = style.with_within( + style.within.with_brackets(within_brackets) + ) + _s = " " if within_brackets else "" + rewrite_run( + # language=python + python( + """\ + a = [ i*2 for i in range(0, 10)] + a = [ i*2 for i in [1, 2, 3 ]] + """, + f"""\ + a = [i * 2 for i in range(0, 10)] + a = [i * 2 for i in [1, 2, 3]] + """.replace("[", "[" + _s).replace("]", _s + "]") + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +def test_spaces_with_generator_comprehension(): + style = IntelliJ.spaces() + + rewrite_run( + # language=python + python( + """\ + a = ( i*2 for i in range(0, 10)) + """, + f"""\ + a = (i * 2 for i in range(0, 10)) + """ + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +@pytest.mark.parametrize("within_brackets", [False, True]) +def test_spaces_with_list_comprehension_with_condition(within_brackets): + style = IntelliJ.spaces() + style = style.with_within( + style.within.with_brackets(within_brackets) + ) + _s = " " if within_brackets else "" + rewrite_run( + # language=python + python( + """\ + a = [ i* 2 for i in range(0, 10) if i % 2 == 0 ] + """, + """\ + a = [i * 2 for i in range(0, 10) if i % 2 == 0] + """.replace("[", "[" + _s).replace("]", _s + "]") + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +@pytest.mark.parametrize("within_braces", [False, True]) +def test_spaces_with_set_comprehension(within_braces): + style = IntelliJ.spaces() + style = style.with_within( + style.within.with_braces(within_braces) + ) + _s = " " if within_braces else "" + rewrite_run( + # language=python + python( + """\ + a = {i*2 for i in range(0, 10)} + a = {i for i in {1, 2, 3 }} + """, + """\ + a = {i * 2 for i in range(0, 10)} + a = {i for i in {1, 2, 3}} + """.replace("{", "{" + _s).replace("}", _s + "}") + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + ) + + +@pytest.mark.parametrize("within_braces", [False, True]) +def test_spaces_with_dict_comprehension(within_braces): + style = IntelliJ.spaces() + style = style.with_within( + style.within.with_braces(within_braces) + ) + _s = " " if within_braces else "" + rewrite_run( + # language=python + python( + """\ + a = {i: i*2 for i in range(0, 10)} + a = {i: i for i in [1, 2, 3]} + a = {k: v*2 for k,v in { "a": 2, "b": 4}.items( ) } + """, + """\ + a = {i: i * 2 for i in range(0, 10)} + a = {i: i for i in [1, 2, 3]} + a = {k: v * 2 for k, v in {"a": 2, "b": 4}.items()} + """.replace("{", "{" + _s).replace("}", _s + "}") + ), + spec=RecipeSpec() + .with_recipe(from_visitor(SpacesVisitor(style))) + )