Route every HTTP response through a responder: GZipResponder when the request's
Accept-Encoding contains "gzip", otherwise the new pass-through
IdentityResponder. Both share send_with_compression(); subclasses only override
apply_compression() and content_encoding.

Reviewer note: the "did the body change?" checks read the original payload via
message.get("body", b"") rather than message["body"], matching how the same
method already reads it a few lines earlier, so a body message that omits the
"body" key does not raise KeyError.

diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py
index fc63e91b6..c7fd5b77e 100644
--- a/starlette/middleware/gzip.py
+++ b/starlette/middleware/gzip.py
@@ -15,17 +15,24 @@ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
         self.compresslevel = compresslevel
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
-        if scope["type"] == "http":  # pragma: no branch
-            headers = Headers(scope=scope)
-            if "gzip" in headers.get("Accept-Encoding", ""):
-                responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
-                await responder(scope, receive, send)
-                return
-        await self.app(scope, receive, send)
+        if scope["type"] != "http":  # pragma: no cover
+            await self.app(scope, receive, send)
+            return
 
+        headers = Headers(scope=scope)
+        responder: ASGIApp
+        if "gzip" in headers.get("Accept-Encoding", ""):
+            responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
+        else:
+            responder = IdentityResponder(self.app, self.minimum_size)
 
-class GZipResponder:
-    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+        await responder(scope, receive, send)
+
+
+class IdentityResponder:
+    content_encoding: str
+
+    def __init__(self, app: ASGIApp, minimum_size: int) -> None:
         self.app = app
         self.minimum_size = minimum_size
         self.send: Send = unattached_send
@@ -33,15 +40,12 @@ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> N
         self.started = False
         self.content_encoding_set = False
         self.content_type_is_excluded = False
-        self.gzip_buffer = io.BytesIO()
-        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
 
     async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
         self.send = send
-        with self.gzip_buffer, self.gzip_file:
-            await self.app(scope, receive, self.send_with_gzip)
+        await self.app(scope, receive, self.send_with_compression)
 
-    async def send_with_gzip(self, message: Message) -> None:
+    async def send_with_compression(self, message: Message) -> None:
         message_type = message["type"]
         if message_type == "http.response.start":
             # Don't send the initial message until we've determined how to
@@ -60,53 +64,78 @@ async def send_with_gzip(self, message: Message) -> None:
             body = message.get("body", b"")
             more_body = message.get("more_body", False)
             if len(body) < self.minimum_size and not more_body:
-                # Don't apply GZip to small outgoing responses.
+                # Don't apply compression to small outgoing responses.
                 await self.send(self.initial_message)
                 await self.send(message)
             elif not more_body:
-                # Standard GZip response.
-                self.gzip_file.write(body)
-                self.gzip_file.close()
-                body = self.gzip_buffer.getvalue()
+                # Standard response.
+                body = self.apply_compression(body, more_body=False)
 
                 headers = MutableHeaders(raw=self.initial_message["headers"])
-                headers["Content-Encoding"] = "gzip"
-                headers["Content-Length"] = str(len(body))
                 headers.add_vary_header("Accept-Encoding")
-                message["body"] = body
+                if body != message.get("body", b""):  # "body" key may be absent
+                    headers["Content-Encoding"] = self.content_encoding
+                    headers["Content-Length"] = str(len(body))
+                    message["body"] = body
 
                 await self.send(self.initial_message)
                 await self.send(message)
             else:
-                # Initial body in streaming GZip response.
+                # Initial body in streaming response.
+                body = self.apply_compression(body, more_body=True)
+
                 headers = MutableHeaders(raw=self.initial_message["headers"])
-                headers["Content-Encoding"] = "gzip"
                 headers.add_vary_header("Accept-Encoding")
-                del headers["Content-Length"]
-
-                self.gzip_file.write(body)
-                message["body"] = self.gzip_buffer.getvalue()
-                self.gzip_buffer.seek(0)
-                self.gzip_buffer.truncate()
+                if body != message.get("body", b""):  # "body" key may be absent
+                    headers["Content-Encoding"] = self.content_encoding
+                    del headers["Content-Length"]
+                    message["body"] = body
 
                 await self.send(self.initial_message)
                 await self.send(message)
-
         elif message_type == "http.response.body":  # pragma: no branch
-            # Remaining body in streaming GZip response.
+            # Remaining body in streaming response.
             body = message.get("body", b"")
             more_body = message.get("more_body", False)
 
-            self.gzip_file.write(body)
-            if not more_body:
-                self.gzip_file.close()
-
-            message["body"] = self.gzip_buffer.getvalue()
-            self.gzip_buffer.seek(0)
-            self.gzip_buffer.truncate()
+            message["body"] = self.apply_compression(body, more_body=more_body)
 
             await self.send(message)
 
+    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+        """Apply compression on the response body.
+
+        If more_body is False, any compression file should be closed. If it
+        isn't, it won't be closed automatically until all background tasks
+        complete.
+        """
+        return body
+
+
+class GZipResponder(IdentityResponder):
+    content_encoding = "gzip"
+
+    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
+        super().__init__(app, minimum_size)
+
+        self.gzip_buffer = io.BytesIO()
+        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        with self.gzip_buffer, self.gzip_file:
+            await super().__call__(scope, receive, send)
+
+    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
+        self.gzip_file.write(body)
+        if not more_body:
+            self.gzip_file.close()
+
+        body = self.gzip_buffer.getvalue()
+        self.gzip_buffer.seek(0)
+        self.gzip_buffer.truncate()
+
+        return body
+
 
 async def unattached_send(message: Message) -> typing.NoReturn:
     raise RuntimeError("send awaitable not set")  # pragma: no cover
diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py
index 38a4e1e35..48ded6ae8 100644
--- a/tests/middleware/test_gzip.py
+++ b/tests/middleware/test_gzip.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.gzip import GZipMiddleware
@@ -21,6 +23,7 @@ def homepage(request: Request) -> PlainTextResponse:
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "gzip"
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert int(response.headers["Content-Length"]) < 4000
 
 
@@ -38,6 +41,7 @@ def homepage(request: Request) -> PlainTextResponse:
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert "Content-Encoding" not in response.headers
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert int(response.headers["Content-Length"]) == 4000
 
 
@@ -57,6 +61,7 @@ def homepage(request: Request) -> PlainTextResponse:
     assert response.status_code == 200
     assert response.text == "OK"
     assert "Content-Encoding" not in response.headers
+    assert "Vary" not in response.headers
     assert int(response.headers["Content-Length"]) == 2
 
 
@@ -79,6 +84,30 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "gzip"
+    assert response.headers["Vary"] == "Accept-Encoding"
+    assert "Content-Length" not in response.headers
+
+
+def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
+    def homepage(request: Request) -> StreamingResponse:
+        async def generator(bytes: bytes, count: int) -> ContentStream:
+            for index in range(count):
+                yield bytes
+
+        streaming = generator(bytes=b"x" * 400, count=10)
+        return StreamingResponse(streaming, status_code=200)
+
+    app = Starlette(
+        routes=[Route("/", endpoint=homepage)],
+        middleware=[Middleware(GZipMiddleware)],
+    )
+
+    client = test_client_factory(app)
+    response = client.get("/", headers={"accept-encoding": "identity"})
+    assert response.status_code == 200
+    assert response.text == "x" * 4000
+    assert "Content-Encoding" not in response.headers
+    assert response.headers["Vary"] == "Accept-Encoding"
     assert "Content-Length" not in response.headers
 
 
@@ -103,6 +132,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
     assert response.status_code == 200
     assert response.text == "x" * 4000
     assert response.headers["Content-Encoding"] == "text"
+    assert "Vary" not in response.headers
     assert "Content-Length" not in response.headers
 
 