Skip to content

Commit

Permalink
Turn create_pipe_input into a context manager (breaking change).
Browse files Browse the repository at this point in the history
This makes context managers of the following:
- `create_pipe_input()`
- `PosixPipeInput`
- `Win32PipeInput`

The reason for this change is that the close method of the pipe should only
close the write-end, and as a consequence of that, the read-end should trigger
an `EOFError`. Before this change, the read-end was also closed, and that
caused the key input to never wake up and "read" the end-of-file. However, we
still want to close the read end at some point, and that's why this is a
context manager now.

As part of this change, exceptions that are raised in the TelnetServer interact
method, won't cause cause the whole server to crash.

See also: #1585

Co-Author: Frank Wu <kwyd@163.com>
  • Loading branch information
jonathanslenders committed Mar 9, 2022
1 parent 045e2b2 commit 97ac514
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 89 deletions.
11 changes: 2 additions & 9 deletions docs/pages/advanced_topics/unit_testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,14 @@ In the following example we use a
from prompt_toolkit.output import DummyOutput
def test_prompt_session():
inp = create_pipe_input()
try:
with create_pipe_input() as inp:
inp.send_text("hello\n")
session = PromptSession(
input=inp,
output=DummyOutput(),
)
result = session.prompt()
finally:
inp.close()
assert result == "hello"
Expand Down Expand Up @@ -116,12 +112,9 @@ single fixture that does it for every test. Something like this:
@pytest.fixture(autouse=True, scope="function")
def mock_input():
pipe_input = create_pipe_input()
try:
with create_pipe_input() as pipe_input:
with create_app_session(input=pipe_input, output=DummyOutput()):
yield pipe_input
finally:
pipe_input.close()
Type checking
Expand Down
29 changes: 17 additions & 12 deletions prompt_toolkit/contrib/ssh/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from prompt_toolkit.application.current import AppSession, create_app_session
from prompt_toolkit.data_structures import Size
from prompt_toolkit.eventloop import get_event_loop
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.input import PipeInput, create_pipe_input
from prompt_toolkit.output.vt100 import Vt100_Output

__all__ = ["PromptToolkitSSHSession", "PromptToolkitSSHServer"]
Expand All @@ -28,7 +28,7 @@ def __init__(
# PipInput object, for sending input in the CLI.
# (This is something that we can use in the prompt_toolkit event loop,
# but still write date in manually.)
self._input = create_pipe_input()
self._input: Optional[PipeInput] = None
self._output: Optional[Vt100_Output] = None

# Output object. Don't render to the real stdout, but write everything
Expand Down Expand Up @@ -88,16 +88,17 @@ async def _interact(self) -> None:
self._output = Vt100_Output(
self.stdout, self._get_size, term=term, write_binary=False
)
with create_app_session(input=self._input, output=self._output) as session:
self.app_session = session
try:
await self.interact(self)
except BaseException:
traceback.print_exc()
finally:
# Close the connection.
self._chan.close()
self._input.close()
with create_pipe_input() as self._input:
with create_app_session(input=self._input, output=self._output) as session:
self.app_session = session
try:
await self.interact(self)
except BaseException:
traceback.print_exc()
finally:
# Close the connection.
self._chan.close()
self._input.close()

def terminal_size_changed(
self, width: int, height: int, pixwidth: object, pixheight: object
Expand All @@ -107,6 +108,10 @@ def terminal_size_changed(
self.app_session.app._on_resize()

def data_received(self, data: str, datatype: object) -> None:
if self._input is None:
# Should not happen.
return

self._input.send_text(data)


Expand Down
52 changes: 31 additions & 21 deletions prompt_toolkit/contrib/telnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from prompt_toolkit.data_structures import Size
from prompt_toolkit.eventloop import get_event_loop
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
from prompt_toolkit.input import create_pipe_input
from prompt_toolkit.input import PipeInput, create_pipe_input
from prompt_toolkit.output.vt100 import Vt100_Output
from prompt_toolkit.renderer import print_formatted_text as print_formatted_text
from prompt_toolkit.styles import BaseStyle, DummyStyle
Expand Down Expand Up @@ -87,6 +87,7 @@ def __init__(self, connection: socket.socket, encoding: str) -> None:
self._connection = connection
self._errors = "strict"
self._buffer: List[bytes] = []
self._closed = False

def write(self, data: str) -> None:
data = data.replace("\n", "\r\n")
Expand All @@ -104,6 +105,9 @@ def flush(self) -> None:

self._buffer = []

def close(self) -> None:
self._closed = True

@property
def encoding(self) -> str:
return self._encoding
Expand All @@ -126,6 +130,7 @@ def __init__(
server: "TelnetServer",
encoding: str,
style: Optional[BaseStyle],
vt100_input: PipeInput,
) -> None:

self.conn = conn
Expand All @@ -136,6 +141,7 @@ def __init__(
self.style = style
self._closed = False
self._ready = asyncio.Event()
self.vt100_input = vt100_input
self.vt100_output = None

# Create "Output" object.
Expand All @@ -144,9 +150,6 @@ def __init__(
# Initialize.
_initialize_telnet(conn)

# Create input.
self.vt100_input = create_pipe_input()

# Create output.
def get_size() -> Size:
return self.size
Expand Down Expand Up @@ -197,12 +200,6 @@ def handle_incoming_data() -> None:
with create_app_session(input=self.vt100_input, output=self.vt100_output):
self.context = contextvars.copy_context()
await self.interact(self)
except Exception as e:
print("Got %s" % type(e).__name__, e)
import traceback

traceback.print_exc()
raise
finally:
self.close()

Expand All @@ -222,6 +219,7 @@ def close(self) -> None:
self.vt100_input.close()
get_event_loop().remove_reader(self.conn)
self.conn.close()
self.stdout.close()

def send(self, formatted_text: AnyFormattedText) -> None:
"""
Expand Down Expand Up @@ -336,22 +334,34 @@ def _accept(self) -> None:
conn, addr = self._listen_socket.accept()
logger.info("New connection %r %r", *addr)

connection = TelnetConnection(
conn, addr, self.interact, self, encoding=self.encoding, style=self.style
)
self.connections.add(connection)

# Run application for this connection.
async def run() -> None:
logger.info("Starting interaction %r %r", *addr)
try:
await connection.run_application()
except Exception as e:
print(e)
with create_pipe_input() as vt100_input:
connection = TelnetConnection(
conn,
addr,
self.interact,
self,
encoding=self.encoding,
style=self.style,
vt100_input=vt100_input,
)
self.connections.add(connection)

logger.info("Starting interaction %r %r", *addr)
try:
await connection.run_application()
finally:
self.connections.remove(connection)
logger.info("Stopping interaction %r %r", *addr)
except BaseException as e:
print("Got %s" % type(e).__name__, e)
import traceback

traceback.print_exc()
finally:
self.connections.remove(connection)
self._application_tasks.remove(task)
logger.info("Stopping interaction %r %r", *addr)

task = get_event_loop().create_task(run())
self._application_tasks.append(task)
3 changes: 2 additions & 1 deletion prompt_toolkit/input/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .base import DummyInput, Input
from .base import DummyInput, Input, PipeInput
from .defaults import create_input, create_pipe_input

__all__ = [
# Base.
"Input",
"PipeInput",
"DummyInput",
# Defaults.
"create_input",
Expand Down
1 change: 1 addition & 0 deletions prompt_toolkit/input/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

__all__ = [
"Input",
"PipeInput",
"DummyInput",
]

Expand Down
16 changes: 12 additions & 4 deletions prompt_toolkit/input/defaults.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Optional, TextIO
from typing import ContextManager, Optional, TextIO

from prompt_toolkit.utils import is_windows

Expand Down Expand Up @@ -48,16 +48,24 @@ def create_input(
return Vt100Input(stdin)


def create_pipe_input() -> PipeInput:
def create_pipe_input() -> ContextManager[PipeInput]:
"""
Create an input pipe.
This is mostly useful for unit testing.
Usage::
with create_pipe_input() as input:
input.send_text('inputdata')
Breaking change: In prompt_toolkit 3.0.28 and earlier, this was returning
the `PipeInput` directly, rather than through a context manager.
"""
if is_windows():
from .win32_pipe import Win32PipeInput

return Win32PipeInput()
return Win32PipeInput.create()
else:
from .posix_pipe import PosixPipeInput

return PosixPipeInput()
return PosixPipeInput.create()
70 changes: 55 additions & 15 deletions prompt_toolkit/input/posix_pipe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import ContextManager, TextIO, cast
from contextlib import contextmanager
from typing import ContextManager, Iterator, TextIO, cast

from ..utils import DummyContext
from .base import PipeInput
Expand All @@ -10,6 +11,36 @@
]


class _Pipe:
"Wrapper around os.pipe, that ensures we don't double close any end."

def __init__(self) -> None:
self.read_fd, self.write_fd = os.pipe()
self._read_closed = False
self._write_closed = False

def close_read(self) -> None:
"Close read-end if not yet closed."
if self._read_closed:
return

os.close(self.read_fd)
self._read_closed = True

def close_write(self) -> None:
"Close write-end if not yet closed."
if self._write_closed:
return

os.close(self.write_fd)
self._write_closed = True

def close(self) -> None:
"Close both read and write ends."
self.close_read()
self.close_write()


class PosixPipeInput(Vt100Input, PipeInput):
"""
Input that is send through a pipe.
Expand All @@ -18,14 +49,15 @@ class PosixPipeInput(Vt100Input, PipeInput):
Usage::
input = PosixPipeInput()
input.send_text('inputdata')
with PosixPipeInput.create() as input:
input.send_text('inputdata')
"""

_id = 0

def __init__(self, text: str = "") -> None:
self._r, self._w = os.pipe()
def __init__(self, _pipe: _Pipe, _text: str = "") -> None:
# Private constructor. Users should use the public `.create()` method.
self.pipe = _pipe

class Stdin:
encoding = "utf-8"
Expand All @@ -34,21 +66,30 @@ def isatty(stdin) -> bool:
return True

def fileno(stdin) -> int:
return self._r
return self.pipe.read_fd

super().__init__(cast(TextIO, Stdin()))
self.send_text(text)
self.send_text(_text)

# Identifier for every PipeInput for the hash.
self.__class__._id += 1
self._id = self.__class__._id

@classmethod
@contextmanager
def create(cls, text: str = "") -> Iterator["PosixPipeInput"]:
pipe = _Pipe()
try:
yield PosixPipeInput(_pipe=pipe, _text=text)
finally:
pipe.close()

def send_bytes(self, data: bytes) -> None:
os.write(self._w, data)
os.write(self.pipe.write_fd, data)

def send_text(self, data: str) -> None:
"Send text to the input."
os.write(self._w, data.encode("utf-8"))
os.write(self.pipe.write_fd, data.encode("utf-8"))

def raw_mode(self) -> ContextManager[None]:
return DummyContext()
Expand All @@ -58,12 +99,11 @@ def cooked_mode(self) -> ContextManager[None]:

def close(self) -> None:
"Close pipe fds."
os.close(self._r)
os.close(self._w)

# We should assign `None` to 'self._r` and 'self._w',
# The event loop still needs to know the the fileno for this input in order
# to properly remove it from the selectors.
# Only close the write-end of the pipe. This will unblock the reader
# callback (in vt100.py > _attached_input), which eventually will raise
# `EOFError`. If we'd also close the read-end, then the event loop
# won't wake up the corresponding callback because of this.
self.pipe.close_write()

def typeahead_hash(self) -> str:
"""
Expand Down
Loading

1 comment on commit 97ac514

@wuf
Copy link
Contributor

@wuf wuf commented on 97ac514 Mar 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commmit fixes #1522

Please sign in to comment.