From a799da72730b1da3c1030367bb05be482f22fbbd Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 10 Jun 2022 15:21:44 +0200 Subject: [PATCH] Allow StaticFiles follow symlinks Add test to prevent path traversal Update requirements.txt Update tests --- requirements.txt | 2 +- starlette/staticfiles.py | 2 +- tests/test_staticfiles.py | 49 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index c834ac8b8..d016660c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ trio==0.19.0 # Documentation mkdocs==1.3.0 -mkdocs-material==8.2.8 +mkdocs-material==8.3.3 mkautodoc==0.1.0 # Packaging diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index d09630f35..83c98466c 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -161,7 +161,7 @@ def lookup_path( self, path: str ) -> typing.Tuple[str, typing.Optional[os.stat_result]]: for directory in self.all_directories: - full_path = os.path.realpath(os.path.join(directory, path)) + full_path = os.path.abspath(os.path.join(directory, path)) directory = os.path.realpath(directory) if os.path.commonprefix([full_path, directory]) != directory: # Don't allow misbehaving clients to break out of the static files diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 7d13a0522..382716d49 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -441,3 +441,52 @@ def mock_timeout(*args, **kwargs): response = client.get("/example.txt") assert response.status_code == 500 assert response.text == "Internal Server Error" + + +def test_staticfiles_follows_symlinks(tmpdir, test_client_factory): + statics_path = os.path.join(tmpdir, "statics") + os.mkdir(statics_path) + + symlink_path = os.path.join(tmpdir, "symlink") + os.mkdir(symlink_path) + + symlink_file_path = os.path.join(symlink_path, "index.html") + with open(symlink_file_path, "w") as file: + file.write("

Hello

") + + statics_file_path = os.path.join(statics_path, "index.html") + os.symlink(symlink_file_path, statics_file_path) + + app = StaticFiles(directory=statics_path) + client = test_client_factory(app) + + response = client.get("/index.html") + assert response.url == "http://testserver/index.html" + assert response.status_code == 200 + assert response.text == "

Hello

" + + +def test_staticfiles_disallows_path_traversal_with_symlinks(tmpdir): + statics_path = os.path.join(tmpdir, "statics") + os.mkdir(statics_path) + + symlink_path = os.path.join(tmpdir, "symlink") + os.mkdir(symlink_path) + + symlink_file_path = os.path.join(symlink_path, "index.html") + with open(symlink_file_path, "w") as file: + file.write("

Hello

") + + temp_path = os.path.join(tmpdir, "index.html") + os.symlink(symlink_file_path, temp_path) + + app = StaticFiles(directory=statics_path) + # We can't test this with 'requests', so we test the app directly here. + path = app.get_path({"path": "/../index.html"}) + scope = {"method": "GET"} + + with pytest.raises(HTTPException) as exc_info: + anyio.run(app.get_response, path, scope) + + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "Not Found"