diff --git a/Lib/_pyrepl/readline.py b/Lib/_pyrepl/readline.py index e3444d90477d352..8a5b926fbb1e651 100644 --- a/Lib/_pyrepl/readline.py +++ b/Lib/_pyrepl/readline.py @@ -95,7 +95,8 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader): # Instance fields config: ReadlineConfig - more_lines: Callable[[str], bool] | None = None + more_lines: MoreLinesCallable | None = None + last_used_indentation: str | None = None def __post_init__(self) -> None: super().__post_init__() @@ -154,6 +155,11 @@ def get_trimmed_history(self, maxlength: int) -> list[str]: cut = 0 return self.history[cut:] + def update_last_used_indentation(self) -> None: + indentation = _get_first_indentation(self.buffer) + if indentation is not None: + self.last_used_indentation = indentation + # --- simplified support for reading multiline Python statements --- def collect_keymap(self) -> tuple[tuple[KeySpec, CommandName], ...]: @@ -208,6 +214,28 @@ def _get_previous_line_indent(buffer: list[str], pos: int) -> tuple[int, int | N return prevlinestart, indent +def _get_first_indentation(buffer: list[str]) -> str | None: + indented_line_start = None + for i in range(len(buffer)): + if (i < len(buffer) - 1 + and buffer[i] == "\n" + and buffer[i + 1] in " \t" + ): + indented_line_start = i + 1 + elif indented_line_start is not None and buffer[i] not in " \t\n": + return ''.join(buffer[indented_line_start : i]) + return None + + +def _is_last_char_colon(buffer: list[str]) -> bool: + i = len(buffer) + while i > 0: + i -= 1 + if buffer[i] not in " \t\n": # ignore whitespaces + return buffer[i] == ":" + return False + + class maybe_accept(commands.Command): def do(self) -> None: r: ReadlineAlikeReader @@ -224,9 +252,18 @@ def do(self) -> None: # auto-indent the next line like the previous line prevlinestart, indent = _get_previous_line_indent(r.buffer, r.pos) r.insert("\n") - if not self.reader.paste_mode and indent: - for i in range(prevlinestart, prevlinestart + indent): - r.insert(r.buffer[i]) + if not self.reader.paste_mode: + if indent: + for i in range(prevlinestart, prevlinestart + indent): + r.insert(r.buffer[i]) + r.update_last_used_indentation() + if _is_last_char_colon(r.buffer): + if r.last_used_indentation is not None: + indentation = r.last_used_indentation + else: + # default + indentation = " " * 4 + r.insert(indentation) elif not self.reader.paste_mode: self.finish = True else: diff --git a/Lib/test/test_pyrepl/test_pyrepl.py b/Lib/test/test_pyrepl/test_pyrepl.py index b643ae5895c97e8..930f6759fb0b482 100644 --- a/Lib/test/test_pyrepl/test_pyrepl.py +++ b/Lib/test/test_pyrepl/test_pyrepl.py @@ -5,19 +5,31 @@ from unittest import TestCase from unittest.mock import patch -from .support import FakeConsole, handle_all_events, handle_events_narrow_console -from .support import more_lines, multiline_input, code_to_events +from .support import ( + FakeConsole, + handle_all_events, + handle_events_narrow_console, + more_lines, + multiline_input, + code_to_events, +) from _pyrepl.console import Event from _pyrepl.readline import ReadlineAlikeReader, ReadlineConfig from _pyrepl.readline import multiline_input as readline_multiline_input class TestCursorPosition(TestCase): + def prepare_reader(self, events): + console = FakeConsole(events) + config = ReadlineConfig(readline_completer=None) + reader = ReadlineAlikeReader(console=console, config=config) + return reader + def test_up_arrow_simple(self): # fmt: off code = ( - 'def f():\n' - ' ...\n' + "def f():\n" + " ...\n" ) # fmt: on events = itertools.chain( @@ -34,8 +46,8 @@ def test_up_arrow_simple(self): def test_down_arrow_end_of_input(self): # fmt: off code = ( - 'def f():\n' - ' ...\n' + "def f():\n" + " ...\n" ) # fmt: on events = itertools.chain( @@ -300,6 +312,79 @@ def test_cursor_position_after_wrap_and_move_up(self): self.assertEqual(reader.pos, 10) self.assertEqual(reader.cxy, (1, 1)) + def test_auto_indent_default(self): + # fmt: off + input_code = ( + "def f():\n" + "pass\n\n" + ) + + output_code = ( + "def f():\n" + " pass\n" + " " + ) + # fmt: on + + def test_auto_indent_continuation(self): + # auto indenting according to previous user indentation + # fmt: off + events = itertools.chain( + code_to_events("def f():\n"), + # add backspace to delete default auto-indent + [ + Event(evt="key", data="backspace", raw=bytearray(b"\x7f")), + ], + code_to_events( + " pass\n" + "pass\n\n" + ), + ) + + output_code = ( + "def f():\n" + " pass\n" + " pass\n" + " " + ) + # fmt: on + + reader = self.prepare_reader(events) + output = multiline_input(reader) + self.assertEqual(output, output_code) + + def test_auto_indent_prev_block(self): + # auto indenting according to indentation in different block + # fmt: off + events = itertools.chain( + code_to_events("def f():\n"), + # add backspace to delete default auto-indent + [ + Event(evt="key", data="backspace", raw=bytearray(b"\x7f")), + ], + code_to_events( + " pass\n" + "pass\n\n" + ), + code_to_events( + "def g():\n" + "pass\n\n" + ), + ) + + + output_code = ( + "def g():\n" + " pass\n" + " " + ) + # fmt: on + + reader = self.prepare_reader(events) + output1 = multiline_input(reader) + output2 = multiline_input(reader) + self.assertEqual(output2, output_code) + class TestPyReplOutput(TestCase): def prepare_reader(self, events): @@ -316,14 +401,12 @@ def test_basic(self): def test_multiline_edit(self): events = itertools.chain( - code_to_events("def f():\n ...\n\n"), + code_to_events("def f():\n...\n\n"), [ Event(evt="key", data="up", raw=bytearray(b"\x1bOA")), Event(evt="key", data="up", raw=bytearray(b"\x1bOA")), Event(evt="key", data="up", raw=bytearray(b"\x1bOA")), Event(evt="key", data="right", raw=bytearray(b"\x1bOC")), - Event(evt="key", data="right", raw=bytearray(b"\x1bOC")), - Event(evt="key", data="right", raw=bytearray(b"\x1bOC")), Event(evt="key", data="backspace", raw=bytearray(b"\x7f")), Event(evt="key", data="g", raw=bytearray(b"g")), Event(evt="key", data="down", raw=bytearray(b"\x1bOB")), @@ -334,9 +417,9 @@ def test_multiline_edit(self): reader = self.prepare_reader(events) output = multiline_input(reader) - self.assertEqual(output, "def f():\n ...\n ") + self.assertEqual(output, "def f():\n ...\n ") output = multiline_input(reader) - self.assertEqual(output, "def g():\n ...\n ") + self.assertEqual(output, "def g():\n ...\n ") def test_history_navigation_with_up_arrow(self): events = itertools.chain( @@ -485,6 +568,7 @@ class Dummy: @property def test_func(self): import warnings + warnings.warn("warnings\n") return None @@ -508,12 +592,12 @@ def prepare_reader(self, events): def test_paste(self): # fmt: off code = ( - 'def a():\n' - ' for x in range(10):\n' - ' if x%2:\n' - ' print(x)\n' - ' else:\n' - ' pass\n' + "def a():\n" + " for x in range(10):\n" + " if x%2:\n" + " print(x)\n" + " else:\n" + " pass\n" ) # fmt: on @@ -534,10 +618,10 @@ def test_paste(self): def test_paste_mid_newlines(self): # fmt: off code = ( - 'def f():\n' - ' x = y\n' - ' \n' - ' y = z\n' + "def f():\n" + " x = y\n" + " \n" + " y = z\n" ) # fmt: on @@ -558,16 +642,16 @@ def test_paste_mid_newlines(self): def test_paste_mid_newlines_not_in_paste_mode(self): # fmt: off code = ( - 'def f():\n' - ' x = y\n' - ' \n' - ' y = z\n\n' + "def f():\n" + "x = y\n" + "\n" + "y = z\n\n" ) expected = ( - 'def f():\n' - ' x = y\n' - ' ' + "def f():\n" + " x = y\n" + " " ) # fmt: on @@ -579,20 +663,20 @@ def test_paste_mid_newlines_not_in_paste_mode(self): def test_paste_not_in_paste_mode(self): # fmt: off input_code = ( - 'def a():\n' - ' for x in range(10):\n' - ' if x%2:\n' - ' print(x)\n' - ' else:\n' - ' pass\n\n' + "def a():\n" + "for x in range(10):\n" + "if x%2:\n" + "print(x)\n" + "else:\n" + "pass\n\n" ) output_code = ( - 'def a():\n' - ' for x in range(10):\n' - ' if x%2:\n' - ' print(x)\n' - ' else:' + "def a():\n" + " for x in range(10):\n" + " if x%2:\n" + " print(x)\n" + " else:" ) # fmt: on @@ -605,25 +689,25 @@ def test_bracketed_paste(self): """Test that bracketed paste using \x1b[200~ and \x1b[201~ works.""" # fmt: off input_code = ( - 'def a():\n' - ' for x in range(10):\n' - '\n' - ' if x%2:\n' - ' print(x)\n' - '\n' - ' else:\n' - ' pass\n' + "def a():\n" + " for x in range(10):\n" + "\n" + " if x%2:\n" + " print(x)\n" + "\n" + " else:\n" + " pass\n" ) output_code = ( - 'def a():\n' - ' for x in range(10):\n' - '\n' - ' if x%2:\n' - ' print(x)\n' - '\n' - ' else:\n' - ' pass\n' + "def a():\n" + " for x in range(10):\n" + "\n" + " if x%2:\n" + " print(x)\n" + "\n" + " else:\n" + " pass\n" ) # fmt: on