Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in FileResponse if file is replaced during prepare #10101

Merged
merged 17 commits into from
Dec 4, 2024
1 change: 1 addition & 0 deletions CHANGES/10101.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed race condition in :class:`aiohttp.web.FileResponse` that could have resulted in an incorrect response if the file was replaced on the filesystem during ``prepare`` -- by :user:`bdraco`.
bdraco marked this conversation as resolved.
Show resolved Hide resolved
79 changes: 58 additions & 21 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import io
import os
import pathlib
from contextlib import suppress
Expand All @@ -13,6 +14,7 @@
Callable,
Final,
Optional,
Set,
Tuple,
cast,
)
Expand Down Expand Up @@ -69,6 +71,9 @@
CONTENT_TYPES.add_type(content_type, extension)


_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()


class FileResponse(StreamResponse):
"""A response object can be used to send files."""

Expand Down Expand Up @@ -157,10 +162,10 @@ async def _precondition_failed(
self.content_length = 0
return await super().prepare(request)

def _get_file_path_stat_encoding(
def _open_file_path_stat_encoding(
self, accept_encoding: str
) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
"""Return the file path, stat result, and encoding.
) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]:
"""Return the io object, stat result, and encoding.

If an uncompressed file is returned, the encoding is set to
:py:data:`None`.
Expand All @@ -178,31 +183,72 @@ def _get_file_path_stat_encoding(
# Do not follow symlinks and ignore any non-regular files.
st = compressed_path.lstat()
bdraco marked this conversation as resolved.
Show resolved Hide resolved
if S_ISREG(st.st_mode):
return compressed_path, st, file_encoding
fobj = compressed_path.open("rb")
with suppress(OSError):
bdraco marked this conversation as resolved.
Show resolved Hide resolved
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, file_encoding

# Fallback to the uncompressed file
return file_path, file_path.stat(), None
st = file_path.stat()
if not S_ISREG(st.st_mode):
return None, st, None
fobj = file_path.open("rb")
with suppress(OSError):
# fstat() may not be available on all platforms
# Once we open the file, we want the fstat() to ensure
# the file has not changed between the first stat()
# and the open().
st = os.stat(fobj.fileno())
return fobj, st, None

async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
loop = asyncio.get_running_loop()
# Encoding comparisons should be case-insensitive
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
try:
file_path, st, file_encoding = await loop.run_in_executor(
None, self._get_file_path_stat_encoding, accept_encoding
fobj, st, file_encoding = await loop.run_in_executor(
None, self._open_file_path_stat_encoding, accept_encoding
)
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
except OSError:
# Most likely to be FileNotFoundError or OSError for circular
# symlinks in python >= 3.13, so respond with 404.
self.set_status(HTTPNotFound.status_code)
return await super().prepare(request)

# Forbid special files like sockets, pipes, devices, etc.
if not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)
try:
# Forbid special files like sockets, pipes, devices, etc.
if not fobj or not S_ISREG(st.st_mode):
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

return await self._prepare_open_file(request, fobj, st, file_encoding)
finally:
if fobj:
# We do not await here because we do not want to wait
# for the executor to finish before returning the response
# so the connection can begin servicing another request
# as soon as possible.
close_future = loop.run_in_executor(None, fobj.close)
# Hold a strong reference to the future to prevent it from being
# garbage collected before it completes.
_CLOSE_FUTURES.add(close_future)
close_future.add_done_callback(_CLOSE_FUTURES.remove)

async def _prepare_open_file(
self,
request: "BaseRequest",
fobj: io.BufferedReader,
st: os.stat_result,
file_encoding: Optional[str],
) -> Optional[AbstractStreamWriter]:
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
last_modified = st.st_mtime

Expand Down Expand Up @@ -343,18 +389,9 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
if count == 0 or must_be_empty_body(request.method, self.status):
return await super().prepare(request)

try:
fobj = await loop.run_in_executor(None, file_path.open, "rb")
except PermissionError:
self.set_status(HTTPForbidden.status_code)
return await super().prepare(request)

if start: # be aware that start could be None or int=0 here.
offset = start
else:
offset = 0

try:
return await self._sendfile(request, fobj, offset, count)
finally:
await asyncio.shield(loop.run_in_executor(None, fobj.close))
return await self._sendfile(request, fobj, offset, count)
7 changes: 4 additions & 3 deletions tests/test_web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,16 +579,17 @@ async def test_access_mock_special_resource(
my_special.touch()

real_result = my_special.stat()
real_stat = pathlib.Path.stat
real_stat = os.stat

def mock_stat(self: pathlib.Path, **kwargs: Any) -> os.stat_result:
s = real_stat(self, **kwargs)
def mock_stat(path: Any, **kwargs: Any) -> os.stat_result:
s = real_stat(path, **kwargs)
if os.path.samestat(s, real_result):
mock_mode = S_IFIFO | S_IMODE(s.st_mode)
s = os.stat_result([mock_mode] + list(s)[1:])
return s

monkeypatch.setattr("pathlib.Path.stat", mock_stat)
monkeypatch.setattr("os.stat", mock_stat)

app = web.Application()
app.router.add_static("/", str(tmp_path))
Expand Down
Loading