diff --git a/stripe/api_requestor.py b/stripe/api_requestor.py index 9fbb2c3d9..5e4b2b16e 100644 --- a/stripe/api_requestor.py +++ b/stripe/api_requestor.py @@ -6,6 +6,7 @@ import platform import time import uuid +import warnings import stripe from stripe import error, oauth_error, http_client, version, util, six @@ -83,16 +84,31 @@ def __init__( self.api_version = api_version or stripe.api_version self.stripe_account = account + self._default_proxy = None + from stripe import verify_ssl_certs as verify from stripe import proxy - self._client = ( - client - or stripe.default_http_client - or http_client.new_default_http_client( + if client: + self._client = client + elif stripe.default_http_client: + self._client = stripe.default_http_client + if proxy != self._default_proxy: + warnings.warn( + "stripe.proxy was updated after sending a " + "request - this is a no-op. To use a different proxy, " + "set stripe.default_http_client to a new client " + "configured with the proxy." + ) + else: + # If the stripe.default_http_client has not been set by the user + # yet, we'll set it here. This way, we aren't creating a new + # HttpClient for every request. + stripe.default_http_client = http_client.new_default_http_client( verify_ssl_certs=verify, proxy=proxy ) - ) + self._client = stripe.default_http_client + self._default_proxy = proxy self._last_request_metrics = None diff --git a/tests/test_api_requestor.py b/tests/test_api_requestor.py index 43ed20604..8e199a6cb 100644 --- a/tests/test_api_requestor.py +++ b/tests/test_api_requestor.py @@ -231,14 +231,17 @@ def setup_stripe(self): orig_attrs = { "api_key": stripe.api_key, "api_version": stripe.api_version, + "default_http_client": stripe.default_http_client, "enable_telemetry": stripe.enable_telemetry, } stripe.api_key = "sk_test_123" stripe.api_version = "2017-12-14" + stripe.default_http_client = None stripe.enable_telemetry = False yield stripe.api_key = orig_attrs["api_key"] stripe.api_version = orig_attrs["api_version"] + stripe.default_http_client = orig_attrs["default_http_client"] stripe.enable_telemetry = orig_attrs["enable_telemetry"] @pytest.fixture @@ -485,6 +488,25 @@ def test_uses_instance_account( ), ) + def test_sets_default_http_client(self, http_client): + assert not stripe.default_http_client + + stripe.api_requestor.APIRequestor(client=http_client) + + # default_http_client is not populated if a client is provided + assert not stripe.default_http_client + + stripe.api_requestor.APIRequestor() + + # default_http_client is set when no client is specified + assert stripe.default_http_client + + new_default_client = stripe.default_http_client + stripe.api_requestor.APIRequestor() + + # the newly created client is reused + assert stripe.default_http_client == new_default_client + def test_uses_app_info(self, requestor, mock_response, check_call): try: old = stripe.app_info diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 000000000..4ca4e28d2 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,128 @@ +import sys +from threading import Thread +import json +import warnings + +import stripe +import pytest + +if sys.version_info[0] < 3: + from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +else: + from http.server import BaseHTTPRequestHandler, HTTPServer + + +class TestIntegration(object): + @pytest.fixture(autouse=True) + def close_mock_server(self): + yield + if self.mock_server: + self.mock_server.shutdown() + self.mock_server.server_close() + self.mock_server_thread.join() + + @pytest.fixture(autouse=True) + def setup_stripe(self): + orig_attrs = { + "api_base": stripe.api_base, + "api_key": stripe.api_key, + "default_http_client": stripe.default_http_client, + "proxy": stripe.proxy, + } + stripe.api_base = "http://localhost:12111" # stripe-mock + stripe.api_key = "sk_test_123" + stripe.default_http_client = None + stripe.proxy = None + yield + stripe.api_base = orig_attrs["api_base"] + stripe.api_key = orig_attrs["api_key"] + stripe.default_http_client = orig_attrs["default_http_client"] + stripe.proxy = orig_attrs["proxy"] + + def setup_mock_server(self, handler): + # Configure mock server. + # Passing 0 as the port will cause a random free port to be chosen. + self.mock_server = HTTPServer(("localhost", 0), handler) + _, self.mock_server_port = self.mock_server.server_address + + # Start running mock server in a separate thread. + # Daemon threads automatically shut down when the main process exits. + self.mock_server_thread = Thread(target=self.mock_server.serve_forever) + self.mock_server_thread.setDaemon(True) + self.mock_server_thread.start() + + def test_hits_api_base(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + + def do_GET(self): + self.__class__.num_requests += 1 + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + return + + self.setup_mock_server(MockServerRequestHandler) + + stripe.api_base = "http://localhost:%s" % self.mock_server_port + stripe.Balance.retrieve() + assert MockServerRequestHandler.num_requests == 1 + + def test_hits_proxy_through_default_http_client(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + + def do_GET(self): + self.__class__.num_requests += 1 + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + return + + self.setup_mock_server(MockServerRequestHandler) + + stripe.proxy = "http://localhost:%s" % self.mock_server_port + stripe.Balance.retrieve() + assert MockServerRequestHandler.num_requests == 1 + + stripe.proxy = "http://bad-url" + + with warnings.catch_warnings(record=True) as w: + stripe.Balance.retrieve() + assert len(w) == 1 + assert "stripe.proxy was updated after sending a request" in str( + w[0].message + ) + + assert MockServerRequestHandler.num_requests == 2 + + def test_hits_proxy_through_custom_client(self): + class MockServerRequestHandler(BaseHTTPRequestHandler): + num_requests = 0 + + def do_GET(self): + self.__class__.num_requests += 1 + + self.send_response(200) + self.send_header( + "Content-Type", "application/json; charset=utf-8" + ) + self.end_headers() + self.wfile.write(json.dumps({}).encode("utf-8")) + return + + self.setup_mock_server(MockServerRequestHandler) + + stripe.default_http_client = stripe.http_client.new_default_http_client( + proxy="http://localhost:%s" % self.mock_server_port + ) + stripe.Balance.retrieve() + assert MockServerRequestHandler.num_requests == 1