Skip to content

Commit

Permalink
Allow StaticFiles follow symlinks
Browse files Browse the repository at this point in the history
Add test to prevent path traversal

Update requirements.txt

Update tests
  • Loading branch information
aminalaee committed Jun 11, 2022
1 parent daf2913 commit a799da7
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ trio==0.19.0

# Documentation
mkdocs==1.3.0
mkdocs-material==8.2.8
mkdocs-material==8.3.3
mkautodoc==0.1.0

# Packaging
Expand Down
2 changes: 1 addition & 1 deletion starlette/staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def lookup_path(
self, path: str
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
full_path = os.path.realpath(os.path.join(directory, path))
full_path = os.path.abspath(os.path.join(directory, path))
directory = os.path.realpath(directory)
if os.path.commonprefix([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
Expand Down
49 changes: 49 additions & 0 deletions tests/test_staticfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,52 @@ def mock_timeout(*args, **kwargs):
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"


def test_staticfiles_follows_symlinks(tmpdir, test_client_factory):
statics_path = os.path.join(tmpdir, "statics")
os.mkdir(statics_path)

symlink_path = os.path.join(tmpdir, "symlink")
os.mkdir(symlink_path)

symlink_file_path = os.path.join(symlink_path, "index.html")
with open(symlink_file_path, "w") as file:
file.write("<h1>Hello</h1>")

statics_file_path = os.path.join(statics_path, "index.html")
os.symlink(symlink_file_path, statics_file_path)

app = StaticFiles(directory=statics_path)
client = test_client_factory(app)

response = client.get("/index.html")
assert response.url == "http://testserver/index.html"
assert response.status_code == 200
assert response.text == "<h1>Hello</h1>"


def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir):
statics_path = os.path.join(tmpdir, "statics")
os.mkdir(statics_path)

symlink_path = os.path.join(tmpdir, "symlink")
os.mkdir(symlink_path)

symlink_file_path = os.path.join(symlink_path, "index.html")
with open(symlink_file_path, "w") as file:
file.write("<h1>Hello</h1>")

temp_path = os.path.join(tmpdir, "index.html")
os.symlink(symlink_file_path, temp_path)

app = StaticFiles(directory=statics_path)
# We can't test this with 'requests', so we test the app directly here.
path = app.get_path({"path": "/../index.html"})
scope = {"method": "GET"}

with pytest.raises(HTTPException) as exc_info:
anyio.run(app.get_response, path, scope)

assert exc_info.value.status_code == 404
assert exc_info.value.detail == "Not Found"

0 comments on commit a799da7

Please sign in to comment.