Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Allow staticfiles to follow symlinks outside directory" #1681

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions starlette/staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import stat
import typing
from email.utils import parsedate
from pathlib import Path

import anyio

Expand Down Expand Up @@ -52,7 +51,7 @@ def __init__(
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
if check_dir and directory is not None and not Path(directory).is_dir():
if check_dir and directory is not None and not os.path.isdir(directory):
raise RuntimeError(f"Directory '{directory}' does not exist")

def get_directories(
Expand All @@ -78,9 +77,11 @@ def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
assert (
package_directory.is_dir()
package_directory = os.path.normpath(
os.path.join(spec.origin, "..", statics_dir)
)
assert os.path.isdir(
package_directory
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)

Expand All @@ -100,14 +101,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = await self.get_response(path, scope)
await response(scope, receive, send)

def get_path(self, scope: Scope) -> Path:
def get_path(self, scope: Scope) -> str:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return Path(*scope["path"].split("/"))
return os.path.normpath(os.path.join(*scope["path"].split("/")))

async def get_response(self, path: Path, scope: Scope) -> Response:
async def get_response(self, path: str, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
Expand All @@ -130,7 +131,7 @@ async def get_response(self, path: Path, scope: Scope) -> Response:
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
index_path = path.joinpath("index.html")
index_path = os.path.join(path, "index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
Expand All @@ -157,25 +158,20 @@ async def get_response(self, path: Path, scope: Scope) -> Response:
raise HTTPException(status_code=404)

def lookup_path(
self, path: Path
) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
original_path = Path(directory).joinpath(path)
full_path = original_path.resolve()
directory = Path(directory).resolve()
full_path = os.path.realpath(os.path.join(directory, path))
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
continue
try:
stat_result = os.lstat(original_path)
full_path.relative_to(directory)
return full_path, stat_result
except ValueError:
# Allow clients to break out of the static files directory
# if following symlinks.
if stat.S_ISLNK(stat_result.st_mode):
stat_result = os.lstat(full_path)
return full_path, stat_result
return full_path, os.stat(full_path)
except (FileNotFoundError, NotADirectoryError):
continue
return Path(), None
return "", None

def file_response(
self,
Expand Down
29 changes: 2 additions & 27 deletions tests/test_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)

file_path = os.path.join(tmpdir, "example.txt")
with open(file_path, "w") as file:
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("outside root dir")

app = StaticFiles(directory=directory)
Expand Down Expand Up @@ -441,28 +441,3 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks_to_break_out_of_dir(
tmp_path: pathlib.Path, test_client_factory
):
statics_path = tmp_path.joinpath("statics")
statics_path.mkdir()

symlink_path = tmp_path.joinpath("symlink")
symlink_path.mkdir()

symlink_file_path = symlink_path.joinpath("index.html")
with open(symlink_file_path, "w") as file:
file.write("<h1>Hello</h1>")

statics_file_path = statics_path.joinpath("index.html")
statics_file_path.symlink_to(symlink_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)

response = client.get("/index.html")
assert response.url == "http://testserver/index.html"
assert response.status_code == 200
assert response.text == "<h1>Hello</h1>"