Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support streaming requests, which return an IO object rather than par… #725

Merged
merged 3 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 55 additions & 10 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from stripe import error, oauth_error, http_client, version, util, six
from stripe.multipart_data_generator import MultipartDataGenerator
from stripe.six.moves.urllib.parse import urlencode, urlsplit, urlunsplit
from stripe.stripe_response import StripeResponse
from stripe.stripe_response import StripeResponse, StripeStreamResponse


def _encode_datetime(dttime):
Expand Down Expand Up @@ -117,11 +117,18 @@ def format_app_info(cls, info):

def request(self, method, url, params=None, headers=None):
rbody, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers
method.lower(), url, params, headers, is_streaming=False
)
resp = self.interpret_response(rbody, rcode, rheaders)
return resp, my_api_key

def request_stream(self, method, url, params=None, headers=None):
stream, rcode, rheaders, my_api_key = self.request_raw(
method.lower(), url, params, headers, is_streaming=True
)
resp = self.interpret_streaming_response(stream, rcode, rheaders)
return resp, my_api_key

def handle_error_response(self, rbody, rcode, resp, rheaders):
try:
error_data = resp["error"]
Expand Down Expand Up @@ -273,7 +280,14 @@ def request_headers(self, api_key, method):

return headers

def request_raw(self, method, url, params=None, supplied_headers=None):
def request_raw(
self,
method,
url,
params=None,
supplied_headers=None,
is_streaming=False,
):
"""
Mechanism for issuing an API call
"""
Expand Down Expand Up @@ -340,12 +354,21 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
api_version=self.api_version,
)

rbody, rcode, rheaders = self._client.request_with_retries(
method, abs_url, headers, post_data
)
if is_streaming:
(
rcontent,
rcode,
rheaders,
) = self._client.request_stream_with_retries(
method, abs_url, headers, post_data
)
else:
rcontent, rcode, rheaders = self._client.request_with_retries(
method, abs_url, headers, post_data
)

util.log_info("Stripe API response", path=abs_url, response_code=rcode)
util.log_debug("API response body", body=rbody)
util.log_debug("API response body", body=rcontent)

if "Request-Id" in rheaders:
request_id = rheaders["Request-Id"]
Expand All @@ -354,7 +377,10 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
link=util.dashboard_link(request_id),
)

return rbody, rcode, rheaders, my_api_key
return rcontent, rcode, rheaders, my_api_key

def _should_handle_code_as_error(self, rcode):
return not 200 <= rcode < 300

def interpret_response(self, rbody, rcode, rheaders):
try:
Expand All @@ -369,7 +395,26 @@ def interpret_response(self, rbody, rcode, rheaders):
rcode,
rheaders,
)
if not 200 <= rcode < 300:
if self._should_handle_code_as_error(rcode):
self.handle_error_response(rbody, rcode, resp.data, rheaders)

return resp

def interpret_streaming_response(self, stream, rcode, rheaders):
# Streaming response are handled with minimal processing for the success
# case (ie. we don't want to read the content). When an error is
# received, we need to read from the stream and parse the received JSON,
# treating it like a standard JSON response.
if self._should_handle_code_as_error(rcode):
if hasattr(stream, "getvalue"):
json_content = stream.getvalue()
elif hasattr(stream, "read"):
json_content = stream.read()
else:
raise NotImplementedError(
"HTTP client %s does not return an IOBase object which "
"can be consumed when streaming a response."
)

return self.interpret_response(json_content, rcode, rheaders)
else:
return StripeStreamResponse(stream, rcode, rheaders)
20 changes: 20 additions & 0 deletions stripe/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,23 @@ def _static_request(
return util.convert_to_stripe_object(
response, api_key, stripe_version, stripe_account
)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
@classmethod
def _static_request_stream(
cls,
method_,
url_,
api_key=None,
idempotency_key=None,
stripe_version=None,
stripe_account=None,
**params
):
requestor = api_requestor.APIRequestor(
api_key, api_version=stripe_version, account=stripe_account
)
headers = util.populate_headers(idempotency_key)
response, _ = requestor.request_stream(method_, url_, params, headers)
return response
19 changes: 16 additions & 3 deletions stripe/api_resources/abstract/custom_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from stripe.six.moves.urllib.parse import quote_plus


def custom_method(name, http_verb, http_path=None):
def custom_method(name, http_verb, http_path=None, is_streaming=False):
if http_verb not in ["get", "post", "delete"]:
raise ValueError(
"Invalid http_verb: %s. Must be one of 'get', 'post' or 'delete'"
Expand All @@ -22,17 +22,30 @@ def custom_method_request(cls, sid, **params):
)
return cls._static_request(http_verb, url, **params)

def custom_method_request_stream(cls, sid, **params):
url = "%s/%s/%s" % (
cls.class_url(),
quote_plus(util.utf8(sid)),
http_path,
)
return cls._static_request_stream(http_verb, url, **params)

if is_streaming:
class_method_impl = classmethod(custom_method_request_stream)
else:
class_method_impl = classmethod(custom_method_request)

existing_method = getattr(cls, name, None)
if existing_method is None:
setattr(cls, name, classmethod(custom_method_request))
setattr(cls, name, class_method_impl)
else:
# If a method with the same name we want to use already exists on
# the class, we assume it's an instance method. In this case, the
# new class method is prefixed with `_cls_`, and the original
# instance method is decorated with `util.class_method_variant` so
# that the new class method is called when the original method is
# called as a class method.
setattr(cls, "_cls_" + name, classmethod(custom_method_request))
setattr(cls, "_cls_" + name, class_method_impl)
instance_method = util.class_method_variant("_cls_" + name)(
existing_method
)
Expand Down
109 changes: 98 additions & 11 deletions stripe/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
self._thread_local = threading.local()

def request_with_retries(self, method, url, headers, post_data=None):
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=False
)

def request_stream_with_retries(
self, method, url, headers, post_data=None
):
return self._request_with_retries_internal(
method, url, headers, post_data, is_streaming=True
)

def _request_with_retries_internal(
self, method, url, headers, post_data, is_streaming
):
self._add_telemetry_header(headers)

num_retries = 0
Expand All @@ -120,7 +134,12 @@ def request_with_retries(self, method, url, headers, post_data=None):
request_start = _now_ms()

try:
response = self.request(method, url, headers, post_data)
if is_streaming:
response = self.request_stream(
method, url, headers, post_data
)
else:
response = self.request(method, url, headers, post_data)
connection_error = None
except error.APIConnectionError as e:
connection_error = e
Expand Down Expand Up @@ -155,6 +174,11 @@ def request(self, method, url, headers, post_data=None):
"HTTPClient subclasses must implement `request`"
)

def request_stream(self, method, url, headers, post_data=None):
raise NotImplementedError(
"HTTPClient subclasses must implement `request_stream`"
)

def _should_retry(self, response, api_connection_error, num_retries):
if num_retries >= self._max_network_retries():
return False
Expand Down Expand Up @@ -269,6 +293,16 @@ def __init__(self, timeout=80, session=None, **kwargs):
self._timeout = timeout

def request(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=False
)

def request_stream(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=True
)

def _request_internal(self, method, url, headers, post_data, is_streaming):
kwargs = {}
if self._verify_ssl_certs:
kwargs["verify"] = stripe.ca_bundle_path
Expand All @@ -278,6 +312,9 @@ def request(self, method, url, headers, post_data=None):
if self._proxy:
kwargs["proxies"] = self._proxy

if is_streaming:
kwargs["stream"] = True

if getattr(self._thread_local, "session", None) is None:
self._thread_local.session = self._session or requests.Session()

Expand All @@ -301,10 +338,14 @@ def request(self, method, url, headers, post_data=None):
"underlying error was: %s" % (e,)
)

# This causes the content to actually be read, which could cause
# e.g. a socket timeout. TODO: The other fetch methods probably
# are susceptible to the same and should be updated.
content = result.content
if is_streaming:
content = result.raw
else:
# This causes the content to actually be read, which could cause
# e.g. a socket timeout. TODO: The other fetch methods probably
# are susceptible to the same and should be updated.
content = result.content

status_code = result.status_code
except Exception as e:
# Would catch just requests.exceptions.RequestException, but can
Expand Down Expand Up @@ -391,6 +432,16 @@ def __init__(self, verify_ssl_certs=True, proxy=None, deadline=55):
self._deadline = deadline

def request(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=False
)

def request_stream(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=True
)

def _request_internal(self, method, url, headers, post_data, is_streaming):
try:
result = urlfetch.fetch(
url=url,
Expand All @@ -406,7 +457,12 @@ def request(self, method, url, headers, post_data=None):
except urlfetch.Error as e:
self._handle_request_error(e, url)

return result.content, result.status_code, result.headers
if is_streaming:
content = util.io.BytesIO(str.encode(result.content))
else:
content = result.content

return content, result.status_code, result.headers

def _handle_request_error(self, e, url):
if isinstance(e, urlfetch.InvalidURLError):
Expand Down Expand Up @@ -464,6 +520,16 @@ def parse_headers(self, data):
return dict((k.lower(), v) for k, v in six.iteritems(dict(headers)))

def request(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=False
)

def request_stream(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=True
)

def _request_internal(self, method, url, headers, post_data, is_streaming):
b = util.io.BytesIO()
rheaders = util.io.BytesIO()

Expand Down Expand Up @@ -516,11 +582,17 @@ def request(self, method, url, headers, post_data=None):
self._curl.perform()
except pycurl.error as e:
self._handle_request_error(e)
rbody = b.getvalue().decode("utf-8")

if is_streaming:
b.seek(0)
rcontent = b
else:
rcontent = b.getvalue().decode("utf-8")

rcode = self._curl.getinfo(pycurl.RESPONSE_CODE)
headers = self.parse_headers(rheaders.getvalue().decode("utf-8"))

return rbody, rcode, headers
return rcontent, rcode, headers

def _handle_request_error(self, e):
if e.args[0] in [
Expand Down Expand Up @@ -580,6 +652,16 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
self._opener = urllib.request.build_opener(proxy)

def request(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=False
)

def request_stream(self, method, url, headers, post_data=None):
return self._request_internal(
method, url, headers, post_data, is_streaming=True
)

def _request_internal(self, method, url, headers, post_data, is_streaming):
if six.PY3 and isinstance(post_data, six.string_types):
post_data = post_data.encode("utf-8")

Expand All @@ -596,17 +678,22 @@ def request(self, method, url, headers, post_data=None):
if self._opener
else urllib.request.urlopen(req)
)
rbody = response.read()

if is_streaming:
rcontent = response
else:
rcontent = response.read()

rcode = response.code
headers = dict(response.info())
except urllib.error.HTTPError as e:
rcode = e.code
rbody = e.read()
rcontent = e.read()
headers = dict(e.info())
except (urllib.error.URLError, ValueError) as e:
self._handle_request_error(e)
lh = dict((k.lower(), v) for k, v in six.iteritems(dict(headers)))
return rbody, rcode, lh
return rcontent, rcode, lh

def _handle_request_error(self, e):
msg = (
Expand Down
Loading