Skip to content

Commit

Permalink
Fix race in FileResponse if file is replaced during prepare (#10101)
Browse files Browse the repository at this point in the history
(cherry picked from commit 678993a)
  • Loading branch information
bdraco authored and patchback[bot] committed Dec 4, 2024
1 parent f180fc1 commit 1bd4cf1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGES/10101.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed race condition in :class:`aiohttp.web.FileResponse` that could have resulted in an incorrect response if the file was replaced on the file system during ``prepare`` -- by :user:`bdraco`.
79 changes: 58 additions & 21 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import io
import os
import pathlib
import sys
Expand All @@ -16,6 +17,7 @@
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
Expand Down Expand Up @@ -73,6 +75,9 @@
CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined]


_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()


class FileResponse(StreamResponse):
"""A response object can be used to send files."""

Expand Down Expand Up @@ -161,10 +166,10 @@ async def _precondition_failed(
self.content_length = 0
return await super().prepare(request)

def _get_file_path_stat_encoding(
def _open_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
"""Return the file path, stat result, and encoding.
) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]:
"""Return the io object, stat result, and encoding.
If an uncompressed file is returned, the encoding is set to
:py:data:`None`.
Expand All @@ -182,31 +187,72 @@ def _get_file_path_stat_encoding(
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
if S_ISREG(st.st_mode):
return compressed_path, st, file_encoding
fobj = compressed_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, file_encoding

# Fallback to the uncompressed file
return file_path, file_path.stat(), None
st = file_path.stat()
if not S_ISREG(st.st_mode):
return None, st, None
fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, None

async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
file_path, st, file_encoding = await loop.run_in_executor(
None, self._get_file_path_stat_encoding, accept_encoding
fobj, st, file_encoding = await loop.run_in_executor(
None, self._open_file_path_stat_encoding, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
except OSError:
# Most likely to be FileNotFoundError or OSError for circular
# symlinks in python >= 3.13, so respond with 404.
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)

# Forbid special files like sockets, pipes, devices, etc.
if not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
try:
# Forbid special files like sockets, pipes, devices, etc.
if not fobj or not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

return await self._prepare_open_file(request, fobj, st, file_encoding)
finally:
if fobj:
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)

async def _prepare_open_file(
self,
request: "BaseRequest",
fobj: io.BufferedReader,
st: os.stat_result,
file_encoding: Optional[str],
) -> Optional[AbstractStreamWriter]:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

Expand Down Expand Up @@ -349,18 +395,9 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)

try:
fobj = await loop.run_in_executor(None, file_path.open, "rb")
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

if start: # be aware that start could be None or int=0 here.
offset = start
else:
offset = 0

try:
return await self._sendfile(request, fobj, offset, count)
finally:
await asyncio.shield(loop.run_in_executor(None, fobj.close))
return await self._sendfile(request, fobj, offset, count)
7 changes: 4 additions & 3 deletions tests/test_web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,16 +585,17 @@ async def test_access_mock_special_resource(
my_special.touch()

real_result = my_special.stat()
real_stat = pathlib.Path.stat
real_stat = os.stat

def mock_stat(self: pathlib.Path, **kwargs: Any) -> os.stat_result:
s = real_stat(self, **kwargs)
def mock_stat(path: Any, **kwargs: Any) -> os.stat_result:
s = real_stat(path, **kwargs)
if os.path.samestat(s, real_result):
mock_mode = S_IFIFO | S_IMODE(s.st_mode)
s = os.stat_result([mock_mode] + list(s)[1:])
return s

monkeypatch.setattr("pathlib.Path.stat", mock_stat)
monkeypatch.setattr("os.stat", mock_stat)

app = web.Application()
app.router.add_static("/", str(tmp_path))
Expand Down

0 comments on commit 1bd4cf1

Please sign in to comment.