diff --git a/cachecontrol/controller.py b/cachecontrol/controller.py index e411a32..9555a0b 100644 --- a/cachecontrol/controller.py +++ b/cachecontrol/controller.py @@ -164,7 +164,7 @@ def cached_request(self, request): # with cache busting headers as usual (ie no-cache). if int(resp.status) in PERMANENT_REDIRECT_STATUSES: msg = ( - 'Returning cached permanent redirect response ' + "Returning cached permanent redirect response " "(ignoring date and etag information)" ) logger.debug(msg) @@ -268,10 +268,8 @@ def cache_response(self, request, response, body=None, status_codes=None): response_headers = CaseInsensitiveDict(response.headers) - if 'date' in response_headers: - date = calendar.timegm( - parsedate_tz(response_headers['date']) - ) + if "date" in response_headers: + date = calendar.timegm(parsedate_tz(response_headers["date"])) else: date = 0 @@ -319,55 +317,57 @@ def cache_response(self, request, response, body=None, status_codes=None): # If we've been given an etag, then keep the response if self.cache_etags and "etag" in response_headers: expires_time = 0 - if response_headers.get('expires'): - expires = parsedate_tz(response_headers['expires']) + if response_headers.get("expires"): + expires = parsedate_tz(response_headers["expires"]) if expires is not None: expires_time = calendar.timegm(expires) - date expires_time = max(expires_time, 14 * 86400) - logger.debug('etag object cached for {0} seconds'.format(expires_time)) + logger.debug("etag object cached for {0} seconds".format(expires_time)) logger.debug("Caching due to etag") self.cache.set( cache_url, self.serializer.dumps(request, response, body), - expires=expires_time + expires=expires_time, ) # Add to the cache any permanent redirects. We do this before looking # that the Date headers. elif int(response.status) in PERMANENT_REDIRECT_STATUSES: logger.debug("Caching permanent redirect") - self.cache.set(cache_url, self.serializer.dumps(request, response, b'')) + self.cache.set(cache_url, self.serializer.dumps(request, response, b"")) # Add to the cache if the response headers demand it. If there # is no date header then we can't do anything about expiring # the cache. elif "date" in response_headers: - date = calendar.timegm( - parsedate_tz(response_headers['date']) - ) + date = calendar.timegm(parsedate_tz(response_headers["date"])) # cache when there is a max-age > 0 if "max-age" in cc and cc["max-age"] > 0: logger.debug("Caching b/c date exists and max-age > 0") - expires_time = cc['max-age'] + expires_time = cc["max-age"] self.cache.set( cache_url, self.serializer.dumps(request, response, body), - expires=expires_time + expires=expires_time, ) # If the request can expire, it means we should cache it # in the meantime. elif "expires" in response_headers: if response_headers["expires"]: - expires = parsedate_tz(response_headers['expires']) + expires = parsedate_tz(response_headers["expires"]) if expires is not None: expires_time = calendar.timegm(expires) - date else: expires_time = None - logger.debug('Caching b/c of expires header. expires in {0} seconds'.format(expires_time)) + logger.debug( + "Caching b/c of expires header. expires in {0} seconds".format( + expires_time + ) + ) self.cache.set( cache_url, self.serializer.dumps(request, response, body=body), @@ -410,7 +410,6 @@ def update_cached_response(self, request, response): cached_response.status = 200 # update our cache - body = cached_response.read(decode_content=False) - self.cache.set(cache_url, self.serializer.dumps(request, cached_response, body)) + self.cache.set(cache_url, self.serializer.dumps(request, cached_response)) return cached_response diff --git a/cachecontrol/serialize.py b/cachecontrol/serialize.py index 4e49a90..0cb1c8f 100644 --- a/cachecontrol/serialize.py +++ b/cachecontrol/serialize.py @@ -25,10 +25,16 @@ def _b64_decode_str(s): class Serializer(object): - - def dumps(self, request, response, body): + def dumps(self, request, response, body=None): response_headers = CaseInsensitiveDict(response.headers) + if body is None: + # When a body isn't passed in, we'll read the response. We + # also update the response with a new file handler to be + # sure it acts as though it was never read. + body = response.read(decode_content=False) + response._fp = io.BytesIO(body) + # NOTE: This is all a bit weird, but it's really important that on # Python 2.x these objects are unicode and not str, even when # they contain only ascii. The problem here is that msgpack diff --git a/tests/issue_263.py b/tests/issue_263.py new file mode 100644 index 0000000..66075fe --- /dev/null +++ b/tests/issue_263.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +import sys + +import cachecontrol +import requests +from cachecontrol.cache import DictCache +from cachecontrol.heuristics import BaseHeuristic + +import logging + +clogger = logging.getLogger("cachecontrol") +clogger.addHandler(logging.StreamHandler()) +clogger.setLevel(logging.DEBUG) + + +from pprint import pprint + + +class NoAgeHeuristic(BaseHeuristic): + def update_headers(self, response): + if "cache-control" in response.headers: + del response.headers["cache-control"] + + +cache_adapter = cachecontrol.CacheControlAdapter( + DictCache(), cache_etags=True, heuristic=NoAgeHeuristic() +) + + +session = requests.Session() +session.mount("https://", cache_adapter) + + +def log_resp(resp): + return + + print(f"{resp.status_code} {resp.request.method}") + for k, v in response.headers.items(): + print(f"{k}: {v}") + + +for i in range(2): + response = session.get( + "https://api.github.com/repos/sigmavirus24/github3.py/pulls/1033" + ) + log_resp(response) + print(f"Content length: {len(response.content)}") + print(response.from_cache) + if len(response.content) == 0: + sys.exit(1) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 59771c5..4301be4 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -12,7 +12,6 @@ class TestSerializer(object): - def setup(self): self.serializer = Serializer() self.response_data = { @@ -93,7 +92,9 @@ def test_read_latest_version_streamable(self, url): original_resp = requests.get(url, stream=True) req = original_resp.request - resp = self.serializer.loads(req, self.serializer.dumps(req, original_resp.raw, original_resp.content)) + resp = self.serializer.loads( + req, self.serializer.dumps(req, original_resp.raw, original_resp.content) + ) assert resp.read() @@ -120,3 +121,17 @@ def test_no_vary_header(self, url): assert self.serializer.loads( req, self.serializer.dumps(req, original_resp.raw, data) ) + + def test_no_body_creates_response_file_handle_on_dumps(self, url): + original_resp = requests.get(url, stream=True) + data = None + req = original_resp.request + + assert self.serializer.loads( + req, self.serializer.dumps(req, original_resp.raw, data) + ) + + # By passing in data=None it will force a read of the file + # handle. Reading it again proves we're resetting the internal + # file handle with a buffer. + assert original_resp.raw.read()