Skip to content

Commit

Permalink
Update support for @custom_method decorator.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcr-stripe committed Jun 18, 2021
1 parent d92504e commit 9210633
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 13 deletions.
20 changes: 20 additions & 0 deletions stripe/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,23 @@ def _static_request(
return util.convert_to_stripe_object(
response, api_key, stripe_version, stripe_account
)

# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
@classmethod
def _static_request_stream(
cls,
method_,
url_,
api_key=None,
idempotency_key=None,
stripe_version=None,
stripe_account=None,
**params
):
requestor = api_requestor.APIRequestor(
api_key, api_version=stripe_version, account=stripe_account
)
headers = util.populate_headers(idempotency_key)
response, _ = requestor.request_stream(method_, url_, params, headers)
return response
19 changes: 16 additions & 3 deletions stripe/api_resources/abstract/custom_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from stripe.six.moves.urllib.parse import quote_plus


def custom_method(name, http_verb, http_path=None):
def custom_method(name, http_verb, http_path=None, is_streaming=False):
if http_verb not in ["get", "post", "delete"]:
raise ValueError(
"Invalid http_verb: %s. Must be one of 'get', 'post' or 'delete'"
Expand All @@ -22,17 +22,30 @@ def custom_method_request(cls, sid, **params):
)
return cls._static_request(http_verb, url, **params)

def custom_method_request_stream(cls, sid, **params):
url = "%s/%s/%s" % (
cls.class_url(),
quote_plus(util.utf8(sid)),
http_path,
)
return cls._static_request_stream(http_verb, url, **params)

if is_streaming:
class_method_impl = classmethod(custom_method_request_stream)
else:
class_method_impl = classmethod(custom_method_request)

existing_method = getattr(cls, name, None)
if existing_method is None:
setattr(cls, name, classmethod(custom_method_request))
setattr(cls, name, class_method_impl)
else:
# If a method with the same name we want to use already exists on
# the class, we assume it's an instance method. In this case, the
# new class method is prefixed with `_cls_`, and the original
# instance method is decorated with `util.class_method_variant` so
# that the new class method is called when the original method is
# called as a class method.
setattr(cls, "_cls_" + name, classmethod(custom_method_request))
setattr(cls, "_cls_" + name, class_method_impl)
instance_method = util.class_method_variant("_cls_" + name)(
existing_method
)
Expand Down
13 changes: 13 additions & 0 deletions stripe/stripe_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ def request(self, method, url, params=None, headers=None):
response, api_key, self.stripe_version, self.stripe_account
)

def request_stream(self, method, url, params=None, headers=None):
if params is None:
params = self._retrieve_params
requestor = api_requestor.APIRequestor(
key=self.api_key,
api_base=self.api_base(),
api_version=self.stripe_version,
account=self.stripe_account,
)
response, _ = requestor.request_stream(method, url, params, headers)

return response

def __repr__(self):
ident_parts = [type(self).__name__]

Expand Down
73 changes: 73 additions & 0 deletions tests/api_resources/abstract/test_custom_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ class TestCustomMethod(object):
@stripe.api_resources.abstract.custom_method(
"do_stuff", http_verb="post", http_path="do_the_thing"
)
@stripe.api_resources.abstract.custom_method(
"do_stream_stuff",
http_verb="post",
http_path="do_the_stream_thing",
is_streaming=True,
)
class MyResource(stripe.api_resources.abstract.APIResource):
OBJECT_NAME = "myresource"

Expand All @@ -17,6 +23,11 @@ def do_stuff(self, idempotency_key=None, **params):
self.refresh_from(self.request("post", url, params, headers))
return self

def do_stream_stuff(self, idempotency_key=None, **params):
url = self.instance_url() + "/do_the_stream_thing"
headers = util.populate_headers(idempotency_key)
return self.request_stream("post", url, params, headers)

def test_call_custom_method_class(self, request_mock):
request_mock.stub_request(
"post",
Expand All @@ -32,6 +43,26 @@ def test_call_custom_method_class(self, request_mock):
)
assert obj.thing_done is True

def test_call_custom_stream_method_class(self, request_mock):
request_mock.stub_request_stream(
"post",
"/v1/myresources/mid/do_the_stream_thing",
"response body",
rheaders={"request-id": "req_id"},
)

resp = self.MyResource.do_stream_stuff("mid", foo="bar")

request_mock.assert_requested_stream(
"post", "/v1/myresources/mid/do_the_stream_thing", {"foo": "bar"}
)

body_content = resp.io.read()
if hasattr(body_content, "decode"):
body_content = body_content.decode("utf-8")

assert body_content == "response body"

def test_call_custom_method_class_with_object(self, request_mock):
request_mock.stub_request(
"post",
Expand All @@ -48,6 +79,27 @@ def test_call_custom_method_class_with_object(self, request_mock):
)
assert obj.thing_done is True

def test_call_custom_stream_method_class_with_object(self, request_mock):
request_mock.stub_request_stream(
"post",
"/v1/myresources/mid/do_the_stream_thing",
"response body",
rheaders={"request-id": "req_id"},
)

obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = self.MyResource.do_stream_stuff(obj, foo="bar")

request_mock.assert_requested_stream(
"post", "/v1/myresources/mid/do_the_stream_thing", {"foo": "bar"}
)

body_content = resp.io.read()
if hasattr(body_content, "decode"):
body_content = body_content.decode("utf-8")

assert body_content == "response body"

def test_call_custom_method_instance(self, request_mock):
request_mock.stub_request(
"post",
Expand All @@ -63,3 +115,24 @@ def test_call_custom_method_instance(self, request_mock):
"post", "/v1/myresources/mid/do_the_thing", {"foo": "bar"}
)
assert obj.thing_done is True

def test_call_custom_stream_method_instance(self, request_mock):
request_mock.stub_request_stream(
"post",
"/v1/myresources/mid/do_the_stream_thing",
"response body",
rheaders={"request-id": "req_id"},
)

obj = self.MyResource.construct_from({"id": "mid"}, "mykey")
resp = obj.do_stream_stuff(foo="bar")

request_mock.assert_requested_stream(
"post", "/v1/myresources/mid/do_the_stream_thing", {"foo": "bar"}
)

body_content = resp.io.read()
if hasattr(body_content, "decode"):
body_content = body_content.decode("utf-8")

assert body_content == "response body"
81 changes: 71 additions & 10 deletions tests/request_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import json

import stripe
from stripe import six
from stripe.stripe_response import StripeResponse
from stripe import six, util
from stripe.stripe_response import StripeResponse, StripeStreamResponse


class RequestMock(object):
def __init__(self, mocker):
self._mocker = mocker

self._real_request = stripe.api_requestor.APIRequestor.request
self._real_request_stream = (
stripe.api_requestor.APIRequestor.request_stream
)
self._stub_request_handler = StubRequestHandler()

self.constructor_patcher = self._mocker.patch(
Expand All @@ -26,16 +29,42 @@ def __init__(self, mocker):
autospec=True,
)

self.request_stream_patcher = self._mocker.patch(
"stripe.api_requestor.APIRequestor.request_stream",
side_effect=self._patched_request_stream,
autospec=True,
)

def _patched_request(self, requestor, method, url, *args, **kwargs):
response = self._stub_request_handler.get_response(method, url)
response = self._stub_request_handler.get_response(
method, url, expect_stream=False
)
if response is not None:
return response, stripe.api_key

return self._real_request(requestor, method, url, *args, **kwargs)

def _patched_request_stream(self, requestor, method, url, *args, **kwargs):
response = self._stub_request_handler.get_response(
method, url, expect_stream=True
)
if response is not None:
return response, stripe.api_key

return self._real_request_stream(
requestor, method, url, *args, **kwargs
)

def stub_request(self, method, url, rbody={}, rcode=200, rheaders={}):
self._stub_request_handler.register(
method, url, rbody, rcode, rheaders
method, url, rbody, rcode, rheaders, is_streaming=False
)

def stub_request_stream(
self, method, url, rbody={}, rcode=200, rheaders={}
):
self._stub_request_handler.register(
method, url, rbody, rcode, rheaders, is_streaming=True
)

def assert_api_base(self, expected_api_base):
Expand Down Expand Up @@ -84,6 +113,16 @@ def assert_api_version(self, expected_api_version):
raise AssertionError(msg)

def assert_requested(self, method, url, params=None, headers=None):
self.assert_requested_internal(
self.request_patcher, method, url, params, headers
)

def assert_requested_stream(self, method, url, params=None, headers=None):
self.assert_requested_internal(
self.request_stream_patcher, method, url, params, headers
)

def assert_requested_internal(self, patcher, method, url, params, headers):
params = params or self._mocker.ANY
headers = headers or self._mocker.ANY
called = False
Expand All @@ -99,7 +138,7 @@ def assert_requested(self, method, url, params=None, headers=None):

for args in possible_called_args:
try:
self.request_patcher.assert_called_with(*args)
patcher.assert_called_with(*args)
except AssertionError as e:
exception = e
else:
Expand All @@ -117,23 +156,45 @@ def assert_no_request(self):
)
raise AssertionError(msg)

def assert_no_request_stream(self):
if self.request_stream_patcher.call_count != 0:
msg = (
"Expected 'request_stream' to not have been called. "
"Called %s times." % (self.request_stream_patcher.call_count)
)
raise AssertionError(msg)

def reset_mock(self):
self.request_patcher.reset_mock()
self.request_stream_patcher.reset_mock()


class StubRequestHandler(object):
def __init__(self):
self._entries = {}

def register(self, method, url, rbody={}, rcode=200, rheaders={}):
self._entries[(method, url)] = (rbody, rcode, rheaders)
def register(
self, method, url, rbody={}, rcode=200, rheaders={}, is_streaming=False
):
self._entries[(method, url)] = (rbody, rcode, rheaders, is_streaming)

def get_response(self, method, url):
def get_response(self, method, url, expect_stream=False):
if (method, url) in self._entries:
rbody, rcode, rheaders = self._entries.pop((method, url))
rbody, rcode, rheaders, is_streaming = self._entries.pop(
(method, url)
)

if expect_stream != is_streaming:
return None

if not isinstance(rbody, six.string_types):
rbody = json.dumps(rbody)
stripe_response = StripeResponse(rbody, rcode, rheaders)
if is_streaming:
stripe_response = StripeStreamResponse(
util.io.BytesIO(str.encode(rbody)), rcode, rheaders
)
else:
stripe_response = StripeResponse(rbody, rcode, rheaders)
return stripe_response

return None

0 comments on commit 9210633

Please sign in to comment.