Skip to content

Commit

Permalink
Run Pyright & begin adding type annotations (#1020)
Browse files Browse the repository at this point in the history
---------
Co-authored-by: anniel-stripe <97691964+anniel-stripe@users.noreply.github.com>
  • Loading branch information
richardm-stripe authored Aug 23, 2023
1 parent 9ce0588 commit 469b3d6
Show file tree
Hide file tree
Showing 47 changed files with 416 additions and 143 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ jobs:

- uses: stripe/openapi/actions/stripe-mock@master

- name: Typecheck with pyright
run: make pyright

- name: Test with pytest
run: make ci-test

Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ coverage.xml
.pytest_cache/

# pyenv
.python-version
.python_version

# Environments
.env
Expand Down
17 changes: 14 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ venv: $(VENV_NAME)/bin/activate
$(VENV_NAME)/bin/activate: setup.py
$(PIP) install --upgrade pip virtualenv
@test -d $(VENV_NAME) || $(PYTHON) -m virtualenv --clear $(VENV_NAME)
${VENV_NAME}/bin/python -m pip install -U pip tox twine -c constraints.txt
${VENV_NAME}/bin/python -m pip install -e .
${VENV_NAME}/bin/python -m pip install -U pip tox twine pyright -c constraints.txt
@touch $(VENV_NAME)/bin/activate

test: venv
Expand All @@ -25,6 +24,18 @@ coveralls: venv
${VENV_NAME}/bin/python -m pip install -U coveralls
@${VENV_NAME}/bin/tox -e coveralls

pyright: venv
# In order for pyright to be able to follow imports, we need "editable_mode=compat" to force setuptools to do
# an editable install via a .pth file mechanism and not "import hooks". See
# the "editable installs" section of https://github.com/microsoft/pyright/blob/main/docs/import-resolution.md#editable-installs

# This command might fail if we're on python 3.6, as versions of pip that
# support python 3.6 don't know about "--config-settings", but in this case
# we don't need to pass config-settings anyway because "editable_mode=compat" just
# means to perform as these old versions of pip already do.
pip install -e . --config-settings editable_mode=compat || pip install -e .
@${VENV_NAME}/bin/pyright

fmt: venv
@${VENV_NAME}/bin/tox -e fmt

Expand All @@ -43,4 +54,4 @@ update-version:

codegen-format: fmt

.PHONY: ci-test clean codegen-format coveralls fmt fmtcheck lint test test-nomock test-travis update-version venv
.PHONY: ci-test clean codegen-format coveralls fmt fmtcheck lint test test-nomock test-travis update-version venv pyright
2 changes: 2 additions & 0 deletions constraints.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# cryptography 40.0.0 deprecates support for Python 3.6 and PyPy3 < 7.3.10
cryptography<40
# TODO (remove later): pyright 1.1.323 introduces false positive errors
pyright<1.1.323
104 changes: 104 additions & 0 deletions flake8_stripe/flake8_stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Hint: if you're developing this plugin, test changes with:
# venv/bin/tox -e lint -r
# so that tox re-installs the plugin from the local directory
import ast
from typing import Iterator, Tuple


class TypingImportsChecker:
name = __name__
version = "0.1.0"

# Rules:
# * typing_extensions v4.1.1 is the latest that supports Python 3.6
# so don't depend on anything from a more recent version than that.
#
# If we need something newer, maybe we can provide it for users on
# newer versions with a conditional import, but we'll cross that
# bridge when we come to it.

# If a symbol exists in both `typing` and `typing_extensions`, which
# should you use? Prefer `typing_extensions` if the symbol available there.
# in 4.1.1. In typing_extensions 4.7.0, `typing_extensions` started re-exporting
# EVERYTHING from `typing` but this is not the case in v4.1.1.
allowed_typing_extensions_imports = [
"Literal",
"NoReturn",
"Protocol",
"TYPE_CHECKING",
"Type",
"TypedDict",
]

allowed_typing_imports = [
"Any",
"ClassVar",
"Optional",
"TypeVar",
"Union",
"cast",
"overload",
"Dict",
]

def __init__(self, tree: ast.AST):
self.tree = tree

intersection = set(self.allowed_typing_imports) & set(
self.allowed_typing_extensions_imports
)
if len(intersection) > 0:
raise AssertionError(
"TypingImportsChecker: allowed_typing_imports and allowed_typing_extensions_imports must not overlap. Both entries contained: %s"
% (intersection)
)

def run(self) -> Iterator[Tuple[int, int, str, type]]:
for node in ast.walk(self.tree):
if isinstance(node, ast.ImportFrom):
if node.module == "typing":
for name in node.names:
if name.name not in self.allowed_typing_imports:
msg = None
if (
name.name
in self.allowed_typing_extensions_imports
):
msg = (
"SPY100 Don't import %s from 'typing', instead import from 'typing_extensions'"
% (name.name)
)
else:
msg = (
"SPY101 Importing %s from 'typing' is prohibited. Do you need to add to the allowlist in flake8_stripe.py?"
% (name.name)
)
yield (
name.lineno,
name.col_offset,
msg,
type(self),
)
elif node.module == "typing_extensions":
for name in node.names:
if (
name.name
not in self.allowed_typing_extensions_imports
):
msg = None
if name.name in self.allowed_typing_imports:
msg = (
"SPY102 Don't import '%s' from 'typing_extensions', instead import from 'typing'"
% (name.name)
)
else:
msg = (
"SPY103 Importing '%s' from 'typing_extensions' is prohibited. Do you need to add to the allowlist in flake8_stripe.py?"
% (name.name)
)
yield (
name.lineno,
name.col_offset,
msg,
type(self),
)
15 changes: 15 additions & 0 deletions flake8_stripe/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from setuptools import setup

setup(
name="flake8_stripe",
version="0.1.0",
py_modules=["flake8_stripe"],
install_requires=[
"flake8>=3.0.0",
],
entry_points={
"flake8.extension": [
"SPY = flake8_stripe:TypingImportsChecker",
],
},
)
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ exclude = '''
| build/
| dist/
| venv/
| stripe/six.py
)
'''
[tool.pyright]
include=["stripe"]
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
package_data={"stripe": ["data/ca-certificates.crt"]},
zip_safe=False,
install_requires=[
'typing_extensions <= 4.2.0, > 3.7.2; python_version < "3.7"',
'typing_extensions > 3.7.2; python_version >= "3.7"',
'requests >= 2.20; python_version >= "3.0"',
],
python_requires=">=3.6",
Expand Down
2 changes: 1 addition & 1 deletion stripe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
log = None

# API resources
from stripe.api_resources import * # noqa
from stripe.api_resources import * # pyright: ignore # noqa

# OAuth
from stripe.oauth import OAuth # noqa
Expand Down
7 changes: 5 additions & 2 deletions stripe/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from stripe import api_requestor, error, util
from stripe.stripe_object import StripeObject
from urllib.parse import quote_plus
from typing import ClassVar


class APIResource(StripeObject):
OBJECT_NAME: ClassVar[str]

@classmethod
def retrieve(cls, id, api_key=None, **params):
instance = cls(id, api_key, **params)
Expand Down Expand Up @@ -133,7 +136,7 @@ def _static_request(

if idempotency_key is not None:
headers = {} if headers is None else headers.copy()
headers.update(util.populate_headers(idempotency_key))
headers.update(util.populate_headers(idempotency_key)) # type: ignore

response, api_key = requestor.request(method_, url_, params, headers)
return util.convert_to_stripe_object(
Expand Down Expand Up @@ -172,7 +175,7 @@ def _static_request_stream(

if idempotency_key is not None:
headers = {} if headers is None else headers.copy()
headers.update(util.populate_headers(idempotency_key))
headers.update(util.populate_headers(idempotency_key)) # type: ignore

response, _ = requestor.request_stream(method_, url_, params, headers)
return response
13 changes: 11 additions & 2 deletions stripe/api_resources/abstract/listable_api_resource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

from stripe.api_resources.abstract.api_resource import APIResource
from stripe.api_resources.list_object import ListObject


class ListableAPIResource(APIResource):
Expand All @@ -11,12 +12,20 @@ def auto_paging_iter(cls, *args, **params):
@classmethod
def list(
cls, api_key=None, stripe_version=None, stripe_account=None, **params
):
return cls._static_request(
) -> ListObject:
result = cls._static_request(
"get",
cls.class_url(),
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
)

if not isinstance(result, ListObject):
raise TypeError(
"Expected list object from API, got %s"
% (type(result).__name__,)
)

return result
10 changes: 9 additions & 1 deletion stripe/api_resources/abstract/searchable_api_resource.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

from stripe.api_resources.abstract.api_resource import APIResource
from stripe.api_resources.search_result_object import SearchResultObject


class SearchableAPIResource(APIResource):
Expand All @@ -13,14 +14,21 @@ def _search(
stripe_account=None,
**params
):
return cls._static_request(
ret = cls._static_request(
"get",
search_url,
api_key=api_key,
stripe_version=stripe_version,
stripe_account=stripe_account,
params=params,
)
if not isinstance(ret, SearchResultObject):
raise TypeError(
"Expected search result from API, got %s"
% (type(ret).__name__,)
)

return ret

@classmethod
def search(cls, *args, **kwargs):
Expand Down
16 changes: 13 additions & 3 deletions stripe/api_resources/abstract/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
from stripe import error
from urllib.parse import quote_plus

from typing import TypeVar, ClassVar, Any
from typing_extensions import Protocol

class APIResourceTestHelpers:
from stripe.api_resources.abstract.api_resource import APIResource

T = TypeVar("T", bound=APIResource)


class APIResourceTestHelpers(Protocol[T]):
"""
The base type for the TestHelper nested classes.
Handles request URL generation for test_helper custom methods.
Expand All @@ -15,6 +22,9 @@ class Foo(APIResource):
class TestHelpers(APIResourceTestHelpers):
"""

_resource_cls: ClassVar[Any]
resource: T

def __init__(self, resource):
self.resource = resource

Expand All @@ -35,11 +45,11 @@ def class_url(cls):
)
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
base = cls._resource_cls.OBJECT_NAME.replace(".", "/")
base = cls._resource_cls.OBJECT_NAME.replace(".", "/") # type: ignore
return "/v1/test_helpers/%ss" % (base,)

def instance_url(self):
id = self.resource.get("id")
id = getattr(self.resource, "id", None)

if not isinstance(id, str):
raise error.InvalidRequestError(
Expand Down
21 changes: 20 additions & 1 deletion stripe/api_resources/abstract/verify_mixin.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
from __future__ import absolute_import, division, print_function

from typing import Optional
from typing_extensions import Protocol

from stripe.stripe_object import StripeObject


class _Verifiable(Protocol):
def instance_url(self) -> str:
...

def _request(
self,
method: str,
url: str,
idempotency_key: Optional[str],
params: dict,
) -> StripeObject:
...


class VerifyMixin(object):
def verify(self, idempotency_key=None, **params):
def verify(self: _Verifiable, idempotency_key=None, **params):
url = self.instance_url() + "/verify"
return self._request(
"post", url, idempotency_key=idempotency_key, params=params
Expand Down
2 changes: 1 addition & 1 deletion stripe/api_resources/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def instance_url(self):
return self._build_instance_url(self.get("id"))

def deauthorize(self, **params):
params["stripe_user_id"] = self.id
params["stripe_user_id"] = self.id # type: ignore
return oauth.OAuth.deauthorize(**params)

def serialize(self, previous):
Expand Down
2 changes: 1 addition & 1 deletion stripe/api_resources/application_fee_refund.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def modify(cls, fee, sid, **params):
return cls._static_request("post", url, params=params)

def instance_url(self):
return self._build_instance_url(self.fee, self.id)
return self._build_instance_url(self.fee, self.id) # type: ignore

@classmethod
def retrieve(cls, id, api_key=None, **params):
Expand Down
Loading

0 comments on commit 469b3d6

Please sign in to comment.