Skip to content

Commit

Permalink
Revert "Allow staticfiles to follow symlinks outside directory (#1377)"
Browse files Browse the repository at this point in the history
This reverts commit d3dccdc.
  • Loading branch information
Kludex authored Jun 10, 2022
1 parent 4519fba commit fab084a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 51 deletions.
44 changes: 20 additions & 24 deletions starlette/staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import stat
import typing
from email.utils import parsedate
from pathlib import Path

import anyio

Expand Down Expand Up @@ -52,7 +51,7 @@ def __init__(
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
if check_dir and directory is not None and not Path(directory).is_dir():
if check_dir and directory is not None and not os.path.isdir(directory):
raise RuntimeError(f"Directory '{directory}' does not exist")

def get_directories(
Expand All @@ -78,9 +77,11 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
assert (
package_directory.is_dir()
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)

Expand All @@ -100,14 +101,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = await self.get_response(path, scope)
await response(scope, receive, send)

def get_path(self, scope: Scope) -> Path:
def get_path(self, scope: Scope) -> str:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return Path(*scope["path"].split("/"))
return os.path.normpath(os.path.join(*scope["path"].split("/")))

async def get_response(self, path: Path, scope: Scope) -> Response:
async def get_response(self, path: str, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
Expand All @@ -130,7 +131,7 @@ async def get_response(self, path: Path, scope: Scope) -> Response:
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = path.joinpath("index.html")
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
Expand All @@ -157,25 +158,20 @@ async def get_response(self, path: Path, scope: Scope) -> Response:
raise HTTPException(status_code=404)

def lookup_path(
self, path: Path
) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
original_path = Path(directory).joinpath(path)
full_path = original_path.resolve()
directory = Path(directory).resolve()
full_path = os.path.realpath(os.path.join(directory, path))
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
continue
try:
stat_result = os.lstat(original_path)
full_path.relative_to(directory)
return full_path, stat_result
except ValueError:
# Allow clients to break out of the static files directory
# if following symlinks.
if stat.S_ISLNK(stat_result.st_mode):
stat_result = os.lstat(full_path)
return full_path, stat_result
return full_path, os.stat(full_path)
except (FileNotFoundError, NotADirectoryError):
continue
return Path(), None
return "", None

def file_response(
self,
Expand Down
29 changes: 2 additions & 27 deletions tests/test_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)

file_path = os.path.join(tmpdir, "example.txt")
with open(file_path, "w") as file:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("outside root dir")

app = StaticFiles(directory=directory)
Expand Down Expand Up @@ -441,28 +441,3 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks_to_break_out_of_dir(
tmp_path: pathlib.Path, test_client_factory
):
statics_path = tmp_path.joinpath("statics")
statics_path.mkdir()

symlink_path = tmp_path.joinpath("symlink")
symlink_path.mkdir()

symlink_file_path = symlink_path.joinpath("index.html")
with open(symlink_file_path, "w") as file:
file.write("<h1>Hello</h1>")

statics_file_path = statics_path.joinpath("index.html")
statics_file_path.symlink_to(symlink_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)

response = client.get("/index.html")
assert response.url == "http://testserver/index.html"
assert response.status_code == 200
assert response.text == "<h1>Hello</h1>"

0 comments on commit fab084a

Please sign in to comment.