diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 13d375e8..a16e9b50 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,10 @@ jobs: run: make run_tests - name: Upload Coverage uses: codecov/codecov-action@v3 + - name: Run Tests with pydantic v1 + run: | + pip install pydantic==1.10.12 + make tests_only publish: needs: test diff --git a/gotrue/_async/api.py b/gotrue/_async/api.py index f23d26db..77fbf84a 100644 --- a/gotrue/_async/api.py +++ b/gotrue/_async/api.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union -from pydantic import TypeAdapter +from pydantic import parse_obj_as from ..exceptions import APIError from ..helpers import check_response, encode_uri_component @@ -94,7 +94,7 @@ async def list_users(self) -> List[User]: raise APIError("No users found in response", 400) if not isinstance(users, list): raise APIError("Expected a list of users", 400) - return TypeAdapter(List[User]).validate_python(users) + return parse_obj_as(List[User], users) async def sign_up_with_email( self, diff --git a/gotrue/_async/client.py b/gotrue/_async/client.py index 6c625c12..0578033a 100644 --- a/gotrue/_async/client.py +++ b/gotrue/_async/client.py @@ -10,6 +10,7 @@ from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY from ..exceptions import APIError +from ..helpers import model_dump, model_validate from ..types import ( AuthChangeEvent, CookieOptions, @@ -560,7 +561,7 @@ async def _recover_common(self) -> Optional[Tuple[Session, int, int]]: and session_raw and isinstance(session_raw, dict) ): - session = Session.model_validate(session_raw) + session = model_validate(Session, session_raw) expires_at = int(expires_at_raw) time_now = round(time()) return session, expires_at, time_now @@ -628,7 +629,7 @@ async def _save_session(self, *, session: Session) -> None: await self._persist_session(session=session) async def _persist_session(self, *, session: Session) -> None: - data = {"session": session.model_dump(), "expires_at": session.expires_at} + data = {"session": model_dump(session), "expires_at": session.expires_at} await self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str)) async def _remove_session(self) -> None: diff --git a/gotrue/_async/gotrue_admin_api.py b/gotrue/_async/gotrue_admin_api.py index f8b0b422..08b1f770 100644 --- a/gotrue/_async/gotrue_admin_api.py +++ b/gotrue/_async/gotrue_admin_api.py @@ -1,8 +1,9 @@ from __future__ import annotations +from functools import partial from typing import Dict, List, Union -from ..helpers import parse_link_response, parse_user_response +from ..helpers import model_validate, parse_link_response, parse_user_response from ..http_clients import AsyncClient from ..types import ( AdminUserAttributes, @@ -109,7 +110,7 @@ async def list_users(self) -> List[User]: return await self._request( "GET", "admin/users", - xform=lambda data: [User.model_validate(user) for user in data["users"]] + xform=lambda data: [model_validate(User, user) for user in data["users"]] if "users" in data else [], ) @@ -161,7 +162,7 @@ async def _list_factors( return await self._request( "GET", f"admin/users/{params.get('user_id')}/factors", - xform=AuthMFAAdminListFactorsResponse.model_validate, + xform=partial(model_validate, AuthMFAAdminListFactorsResponse), ) async def _delete_factor( @@ -171,5 +172,5 @@ async def _delete_factor( return await self._request( "DELETE", f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}", - xform=AuthMFAAdminDeleteFactorResponse.model_validate, + xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse), ) diff --git a/gotrue/_async/gotrue_base_api.py b/gotrue/_async/gotrue_base_api.py index 4740ea7b..f6ce7f0c 100644 --- a/gotrue/_async/gotrue_base_api.py +++ b/gotrue/_async/gotrue_base_api.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from typing_extensions import Literal, Self -from ..helpers import handle_exception +from ..helpers import handle_exception, model_dump from ..http_clients import AsyncClient T = TypeVar("T") @@ -108,7 +108,7 @@ async def _request( url, headers=headers, params=query, - json=body.model_dump() if isinstance(body, BaseModel) else body, + json=model_dump(body) if isinstance(body, BaseModel) else body, ) response.raise_for_status() result = response if no_resolve_json else response.json() diff --git a/gotrue/_async/gotrue_client.py b/gotrue/_async/gotrue_client.py index dd9ea1b3..f298a90f 100644 --- a/gotrue/_async/gotrue_client.py +++ b/gotrue/_async/gotrue_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from json import loads from time import time from typing import Callable, Dict, List, Tuple, Union @@ -20,7 +21,14 @@ AuthRetryableError, AuthSessionMissingError, ) -from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response +from ..helpers import ( + decode_jwt_payload, + model_dump, + model_dump_json, + model_validate, + parse_auth_response, + parse_user_response, +) from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( @@ -531,7 +539,7 @@ async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: "factors", body=params, jwt=session.access_token, - xform=AuthMFAEnrollResponse.model_validate, + xform=partial(model_validate, AuthMFAEnrollResponse), ) if response.totp.qr_code: response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" @@ -545,7 +553,7 @@ async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeRespon "POST", f"factors/{params.get('factor_id')}/challenge", jwt=session.access_token, - xform=AuthMFAChallengeResponse.model_validate, + xform=partial(model_validate, AuthMFAChallengeResponse), ) async def _challenge_and_verify( @@ -574,9 +582,9 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: f"factors/{params.get('factor_id')}/verify", body=params, jwt=session.access_token, - xform=AuthMFAVerifyResponse.model_validate, + xform=partial(model_validate, AuthMFAVerifyResponse), ) - session = Session.model_validate(response.model_dump()) + session = model_validate(Session, model_dump(response)) await self._save_session(session) self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) return response @@ -589,7 +597,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=AuthMFAUnenrollResponse.model_validate, + xform=partial(AuthMFAUnenrollResponse, model_validate), ) async def _list_factors(self) -> AuthMFAListFactorsResponse: @@ -751,7 +759,7 @@ async def _save_session(self, session: Session) -> None: value = (expire_in - refresh_duration_before_expires) * 1000 await self._start_auto_refresh_token(value) if self._persist_session and session.expires_at: - await self._storage.set_item(self._storage_key, session.model_dump_json()) + await self._storage.set_item(self._storage_key, model_dump_json(session)) async def _start_auto_refresh_token(self, value: float) -> None: if self._refresh_token_timer: @@ -808,7 +816,7 @@ def _get_valid_session( except ValueError: return None try: - return Session.model_validate(data) + return model_validate(Session, data) except Exception: return None diff --git a/gotrue/_sync/api.py b/gotrue/_sync/api.py index aa4eaf20..abbdc480 100644 --- a/gotrue/_sync/api.py +++ b/gotrue/_sync/api.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union -from pydantic import TypeAdapter +from pydantic import parse_obj_as from ..exceptions import APIError from ..helpers import check_response, encode_uri_component @@ -94,7 +94,7 @@ def list_users(self) -> List[User]: raise APIError("No users found in response", 400) if not isinstance(users, list): raise APIError("Expected a list of users", 400) - return TypeAdapter(List[User]).validate_python(users) + return parse_obj_as(List[User], users) def sign_up_with_email( self, diff --git a/gotrue/_sync/client.py b/gotrue/_sync/client.py index 849a3768..d468cbc2 100644 --- a/gotrue/_sync/client.py +++ b/gotrue/_sync/client.py @@ -10,6 +10,7 @@ from ..constants import COOKIE_OPTIONS, DEFAULT_HEADERS, GOTRUE_URL, STORAGE_KEY from ..exceptions import APIError +from ..helpers import model_dump, model_validate from ..types import ( AuthChangeEvent, CookieOptions, @@ -556,7 +557,7 @@ def _recover_common(self) -> Optional[Tuple[Session, int, int]]: and session_raw and isinstance(session_raw, dict) ): - session = Session.model_validate(session_raw) + session = model_validate(Session, session_raw) expires_at = int(expires_at_raw) time_now = round(time()) return session, expires_at, time_now @@ -620,7 +621,7 @@ def _save_session(self, *, session: Session) -> None: self._persist_session(session=session) def _persist_session(self, *, session: Session) -> None: - data = {"session": session.model_dump(), "expires_at": session.expires_at} + data = {"session": model_dump(session), "expires_at": session.expires_at} self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str)) def _remove_session(self) -> None: diff --git a/gotrue/_sync/gotrue_admin_api.py b/gotrue/_sync/gotrue_admin_api.py index b7966217..75f7dba2 100644 --- a/gotrue/_sync/gotrue_admin_api.py +++ b/gotrue/_sync/gotrue_admin_api.py @@ -1,8 +1,9 @@ from __future__ import annotations +from functools import partial from typing import Dict, List, Union -from ..helpers import parse_link_response, parse_user_response +from ..helpers import model_validate, parse_link_response, parse_user_response from ..http_clients import SyncClient from ..types import ( AdminUserAttributes, @@ -109,7 +110,7 @@ def list_users(self) -> List[User]: return self._request( "GET", "admin/users", - xform=lambda data: [User.model_validate(user) for user in data["users"]] + xform=lambda data: [model_validate(User, user) for user in data["users"]] if "users" in data else [], ) @@ -161,7 +162,7 @@ def _list_factors( return self._request( "GET", f"admin/users/{params.get('user_id')}/factors", - xform=AuthMFAAdminListFactorsResponse.model_validate, + xform=partial(model_validate, AuthMFAAdminListFactorsResponse), ) def _delete_factor( @@ -171,5 +172,5 @@ def _delete_factor( return self._request( "DELETE", f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}", - xform=AuthMFAAdminDeleteFactorResponse.model_validate, + xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse), ) diff --git a/gotrue/_sync/gotrue_base_api.py b/gotrue/_sync/gotrue_base_api.py index d459f014..2a9b97fc 100644 --- a/gotrue/_sync/gotrue_base_api.py +++ b/gotrue/_sync/gotrue_base_api.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from typing_extensions import Literal, Self -from ..helpers import handle_exception +from ..helpers import handle_exception, model_dump from ..http_clients import SyncClient T = TypeVar("T") @@ -108,7 +108,7 @@ def _request( url, headers=headers, params=query, - json=body.model_dump() if isinstance(body, BaseModel) else body, + json=model_dump(body) if isinstance(body, BaseModel) else body, ) response.raise_for_status() result = response if no_resolve_json else response.json() diff --git a/gotrue/_sync/gotrue_client.py b/gotrue/_sync/gotrue_client.py index f174ecd6..cc2a4f9a 100644 --- a/gotrue/_sync/gotrue_client.py +++ b/gotrue/_sync/gotrue_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from json import loads from time import time from typing import Callable, Dict, List, Tuple, Union @@ -20,7 +21,13 @@ AuthRetryableError, AuthSessionMissingError, ) -from ..helpers import decode_jwt_payload, parse_auth_response, parse_user_response +from ..helpers import ( + decode_jwt_payload, + model_dump, + model_validate, + parse_auth_response, + parse_user_response, +) from ..http_clients import SyncClient from ..timer import Timer from ..types import ( @@ -529,7 +536,7 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: "factors", body=params, jwt=session.access_token, - xform=AuthMFAEnrollResponse.model_validate, + xform=partial(model_validate, AuthMFAEnrollResponse), ) if response.totp.qr_code: response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" @@ -543,7 +550,7 @@ def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: "POST", f"factors/{params.get('factor_id')}/challenge", jwt=session.access_token, - xform=AuthMFAChallengeResponse.model_validate, + xform=partial(model_validate, AuthMFAChallengeResponse), ) def _challenge_and_verify( @@ -572,9 +579,9 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: f"factors/{params.get('factor_id')}/verify", body=params, jwt=session.access_token, - xform=AuthMFAVerifyResponse.model_validate, + xform=partial(model_validate, AuthMFAVerifyResponse), ) - session = Session.model_validate(response.model_dump()) + session = model_validate(Session, model_dump(response)) self._save_session(session) self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) return response @@ -587,7 +594,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=AuthMFAUnenrollResponse.model_validate, + xform=partial(model_validate, AuthMFAUnenrollResponse), ) def _list_factors(self) -> AuthMFAListFactorsResponse: @@ -806,7 +813,7 @@ def _get_valid_session( except ValueError: return None try: - return Session.model_validate(data) + return model_validate(Session, data) except Exception: return None diff --git a/gotrue/helpers.py b/gotrue/helpers.py index dfce1aaf..79fa33c7 100644 --- a/gotrue/helpers.py +++ b/gotrue/helpers.py @@ -2,9 +2,10 @@ from base64 import b64decode from json import loads -from typing import Any, Union, cast +from typing import Any, Dict, Type, TypeVar, Union, cast from httpx import HTTPStatusError +from pydantic import BaseModel from .errors import AuthApiError, AuthError, AuthRetryableError, AuthUnknownError from .types import ( @@ -16,6 +17,39 @@ UserResponse, ) +TBaseModel = TypeVar("TBaseModel", bound=BaseModel) + + +def model_validate(model: Type[TBaseModel], contents) -> TBaseModel: + """Compatibility layer between pydantic 1 and 2 for parsing an instance + of a BaseModel from varied""" + try: + # pydantic > 2 + return model.model_validate(contents) + except AttributeError: + # pydantic < 2 + return model.parse_obj(contents) + + +def model_dump(model: BaseModel) -> Dict[str, Any]: + """Compatibility layer between pydantic 1 and 2 for dumping a model's contents as a dict""" + try: + # pydantic > 2 + return model.model_dump() + except AttributeError: + # pydantic < 2 + return model.dict() + + +def model_dump_json(model: BaseModel) -> str: + """Compatibility layer between pydantic 1 and 2 for dumping a model's contents as json""" + try: + # pydantic > 2 + return model.model_dump_json() + except AttributeError: + # pydantic < 2 + return model.json() + def parse_auth_response(data: Any) -> AuthResponse: session: Union[Session, None] = None @@ -27,9 +61,9 @@ def parse_auth_response(data: Any) -> AuthResponse: and data["refresh_token"] and data["expires_in"] ): - session = Session.model_validate(data) + session = model_validate(Session, data) user_data = data.get("user", data) - user = User.model_validate(user_data) if user_data else None + user = model_validate(User, user_data) if user_data else None return AuthResponse(session=session, user=user) @@ -41,8 +75,8 @@ def parse_link_response(data: Any) -> GenerateLinkResponse: redirect_to=data.get("redirect_to"), verification_type=data.get("verification_type"), ) - user = User.model_validate( - {k: v for k, v in data.items() if k not in properties.model_dump()} + user = model_validate( + User, {k: v for k, v in data.items() if k not in model_dump(properties)} ) return GenerateLinkResponse(properties=properties, user=user) @@ -50,7 +84,7 @@ def parse_link_response(data: Any) -> GenerateLinkResponse: def parse_user_response(data: Any) -> UserResponse: if "user" not in data: data = {"user": data} - return UserResponse.model_validate(data) + return model_validate(UserResponse, data) def get_error_message(error: Any) -> str: diff --git a/gotrue/types.py b/gotrue/types.py index 67f73b9a..8defdc66 100644 --- a/gotrue/types.py +++ b/gotrue/types.py @@ -4,7 +4,19 @@ from time import time from typing import Any, Callable, Dict, List, Union -from pydantic import BaseModel, model_validator +from pydantic import BaseModel + +try: + # > 2 + from pydantic import model_validator + + model_validator_v1_v2_compat = model_validator(mode="before") +except ImportError: + # < 2 + from pydantic import root_validator + + model_validator_v1_v2_compat = root_validator + from typing_extensions import Literal, NotRequired, TypedDict Provider = Literal[ @@ -106,7 +118,7 @@ class Session(BaseModel): token_type: str user: User - @model_validator(mode="before") + @model_validator_v1_v2_compat def validator(cls, values: dict) -> dict: expires_in = values.get("expires_in") if expires_in and not values.get("expires_at"): @@ -615,22 +627,30 @@ class DecodedJWTDict(TypedDict): amr: NotRequired[Union[List[AMREntry], None]] -AMREntry.model_rebuild() -AuthResponse.model_rebuild() -OAuthResponse.model_rebuild() -UserResponse.model_rebuild() -Session.model_rebuild() -UserIdentity.model_rebuild() -Factor.model_rebuild() -User.model_rebuild() -Subscription.model_rebuild() -AuthMFAVerifyResponse.model_rebuild() -AuthMFAEnrollResponseTotp.model_rebuild() -AuthMFAEnrollResponse.model_rebuild() -AuthMFAUnenrollResponse.model_rebuild() -AuthMFAChallengeResponse.model_rebuild() -AuthMFAListFactorsResponse.model_rebuild() -AuthMFAGetAuthenticatorAssuranceLevelResponse.model_rebuild() -AuthMFAAdminDeleteFactorResponse.model_rebuild() -AuthMFAAdminListFactorsResponse.model_rebuild() -GenerateLinkProperties.model_rebuild() +for model in [ + AMREntry, + AuthResponse, + OAuthResponse, + UserResponse, + Session, + UserIdentity, + Factor, + User, + Subscription, + AuthMFAVerifyResponse, + AuthMFAEnrollResponseTotp, + AuthMFAEnrollResponse, + AuthMFAUnenrollResponse, + AuthMFAChallengeResponse, + AuthMFAListFactorsResponse, + AuthMFAGetAuthenticatorAssuranceLevelResponse, + AuthMFAAdminDeleteFactorResponse, + AuthMFAAdminListFactorsResponse, + GenerateLinkProperties, +]: + try: + # pydantic > 2 + model.rebuild_model() + except AttributeError: + # pydantic < 2 + model.update_forward_refs() diff --git a/pyproject.toml b/pyproject.toml index be54f59d..fc96636f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.8" httpx = ">=0.23,<0.25" -pydantic = "^2.1.0" +pydantic = ">=1.10,<3" [tool.poetry.dev-dependencies] pytest = "^7.3.1"