diff --git a/Lib/_pyrepl/simple_interact.py b/Lib/_pyrepl/simple_interact.py
index 31b2097a78a2268..d65b6d0d62790ac 100644
--- a/Lib/_pyrepl/simple_interact.py
+++ b/Lib/_pyrepl/simple_interact.py
@@ -30,6 +30,7 @@
import linecache
import sys
import code
+import ast
from types import ModuleType
from .readline import _get_reader, multiline_input
@@ -74,9 +75,36 @@ def __init__(
super().__init__(locals=locals, filename=filename, local_exit=local_exit) # type: ignore[call-arg]
self.can_colorize = _colorize.can_colorize()
+ def showsyntaxerror(self, filename=None):
+ super().showsyntaxerror(colorize=self.can_colorize)
+
def showtraceback(self):
super().showtraceback(colorize=self.can_colorize)
+ def runsource(self, source, filename="", symbol="single"):
+ try:
+ tree = ast.parse(source)
+ except (OverflowError, SyntaxError, ValueError):
+ self.showsyntaxerror(filename)
+ return False
+ if tree.body:
+ *_, last_stmt = tree.body
+ for stmt in tree.body:
+ wrapper = ast.Interactive if stmt is last_stmt else ast.Module
+ the_symbol = symbol if stmt is last_stmt else "exec"
+ item = wrapper([stmt])
+ try:
+ code = compile(item, filename, the_symbol)
+ except (OverflowError, ValueError):
+ self.showsyntaxerror(filename)
+ return False
+
+ if code is None:
+ return True
+
+ self.runcode(code)
+ return False
+
def run_multiline_interactive_console(
mainmodule: ModuleType | None= None, future_flags: int = 0
@@ -144,10 +172,7 @@ def more_lines(unicodetext: str) -> bool:
input_name = f""
linecache._register_code(input_name, statement, "") # type: ignore[attr-defined]
- symbol = "single" if not contains_pasted_code else "exec"
- more = console.push(_strip_final_indent(statement), filename=input_name, _symbol=symbol) # type: ignore[call-arg]
- if contains_pasted_code and more:
- more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg]
+ more = console.push(_strip_final_indent(statement), filename=input_name, _symbol="single") # type: ignore[call-arg]
assert not more
input_n += 1
except KeyboardInterrupt:
diff --git a/Lib/code.py b/Lib/code.py
index 9d124563f728c23..d1b8a3256cfeaf3 100644
--- a/Lib/code.py
+++ b/Lib/code.py
@@ -94,7 +94,7 @@ def runcode(self, code):
except:
self.showtraceback()
- def showsyntaxerror(self, filename=None):
+ def showsyntaxerror(self, filename=None, **kwargs):
"""Display the syntax error that just occurred.
This doesn't display a stack trace because there isn't one.
@@ -106,6 +106,7 @@ def showsyntaxerror(self, filename=None):
The output is written by self.write(), below.
"""
+ colorize = kwargs.pop('colorize', False)
type, value, tb = sys.exc_info()
sys.last_exc = value
sys.last_type = type
@@ -123,7 +124,7 @@ def showsyntaxerror(self, filename=None):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_exc = sys.last_value = value
if sys.excepthook is sys.__excepthook__:
- lines = traceback.format_exception_only(type, value)
+ lines = traceback.format_exception_only(type, value, colorize=colorize)
self.write(''.join(lines))
else:
# If someone has set sys.excepthook, we let that take precedence
diff --git a/Lib/test/test_pyrepl.py b/Lib/test/test_pyrepl.py
index c8990b699b214cb..c3c95ed9737f400 100644
--- a/Lib/test/test_pyrepl.py
+++ b/Lib/test/test_pyrepl.py
@@ -1,13 +1,15 @@
import itertools
import os
import rlcompleter
-import sys
import tempfile
import unittest
from code import InteractiveConsole
from functools import partial
from unittest import TestCase
from unittest.mock import MagicMock, patch
+from textwrap import dedent
+import contextlib
+import io
from test.support import requires
from test.support.import_helper import import_module
@@ -1002,5 +1004,88 @@ def test_up_arrow_after_ctrl_r(self):
self.assert_screen_equals(reader, "")
+class TestSimpleInteract(unittest.TestCase):
+ def test_multiple_statements(self):
+ namespace = {}
+ code = dedent("""\
+ class A:
+ def foo(self):
+
+
+ pass
+
+ class B:
+ def bar(self):
+ pass
+
+ a = 1
+ a
+ """)
+ console = InteractiveColoredConsole(namespace, filename="")
+ with (
+ patch.object(InteractiveColoredConsole, "showsyntaxerror") as showsyntaxerror,
+ patch.object(InteractiveColoredConsole, "runsource", wraps=console.runsource) as runsource,
+ ):
+ more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg]
+ self.assertFalse(more)
+ showsyntaxerror.assert_not_called()
+
+
+ def test_multiple_statements_output(self):
+ namespace = {}
+ code = dedent("""\
+ b = 1
+ b
+ a = 1
+ a
+ """)
+ console = InteractiveColoredConsole(namespace, filename="")
+ f = io.StringIO()
+ with contextlib.redirect_stdout(f):
+ more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg]
+ self.assertFalse(more)
+ self.assertEqual(f.getvalue(), "1\n")
+
+ def test_empty(self):
+ namespace = {}
+ code = ""
+ console = InteractiveColoredConsole(namespace, filename="")
+ f = io.StringIO()
+ with contextlib.redirect_stdout(f):
+ more = console.push(code, filename="", _symbol="single") # type: ignore[call-arg]
+ self.assertFalse(more)
+ self.assertEqual(f.getvalue(), "")
+
+ def test_runsource_compiles_and_runs_code(self):
+ console = InteractiveColoredConsole()
+ source = "print('Hello, world!')"
+ with patch.object(console, "runcode") as mock_runcode:
+ console.runsource(source)
+ mock_runcode.assert_called_once()
+
+ def test_runsource_returns_false_for_successful_compilation(self):
+ console = InteractiveColoredConsole()
+ source = "print('Hello, world!')"
+ result = console.runsource(source)
+ self.assertFalse(result)
+
+ def test_runsource_returns_false_for_failed_compilation(self):
+ console = InteractiveColoredConsole()
+ source = "print('Hello, world!'"
+ result = console.runsource(source)
+ self.assertFalse(result)
+
+ def test_runsource_shows_syntax_error_for_failed_compilation(self):
+ console = InteractiveColoredConsole()
+ source = "print('Hello, world!'"
+ with patch.object(console, "showsyntaxerror") as mock_showsyntaxerror:
+ console.runsource(source)
+ mock_showsyntaxerror.assert_called_once()
+
+
+if __name__ == '__main__':
+ unittest.main()
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/traceback.py b/Lib/traceback.py
index 9401b461497cc1e..280d92d04cac9b4 100644
--- a/Lib/traceback.py
+++ b/Lib/traceback.py
@@ -155,7 +155,7 @@ def format_exception(exc, /, value=_sentinel, tb=_sentinel, limit=None, \
return list(te.format(chain=chain, colorize=colorize))
-def format_exception_only(exc, /, value=_sentinel, *, show_group=False):
+def format_exception_only(exc, /, value=_sentinel, *, show_group=False, **kwargs):
"""Format the exception part of a traceback.
The return value is a list of strings, each ending in a newline.
@@ -170,10 +170,11 @@ def format_exception_only(exc, /, value=_sentinel, *, show_group=False):
:exc:`BaseExceptionGroup`, the nested exceptions are included as
well, recursively, with indentation relative to their nesting depth.
"""
+ colorize = kwargs.get("colorize", False)
if value is _sentinel:
value = exc
te = TracebackException(type(value), value, None, compact=True)
- return list(te.format_exception_only(show_group=show_group))
+ return list(te.format_exception_only(show_group=show_group, colorize=colorize))
# -- not official API but folk probably use these two functions.