Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade gotrue-py to pydantic > 2.1.x #286

Merged
merged 7 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gotrue/_async/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict, List, Optional, Union

from pydantic import parse_obj_as
from pydantic import TypeAdapter

from ..exceptions import APIError
from ..helpers import check_response, encode_uri_component
Expand Down Expand Up @@ -94,7 +94,7 @@ async def list_users(self) -> List[User]:
raise APIError("No users found in response", 400)
if not isinstance(users, list):
raise APIError("Expected a list of users", 400)
return parse_obj_as(List[User], users)
return TypeAdapter(List[User]).validate_python(users)

async def sign_up_with_email(
self,
Expand Down
4 changes: 2 additions & 2 deletions gotrue/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ async def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.parse_obj(session_raw)
session = Session.model_validate(session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -628,7 +628,7 @@ async def _save_session(self, *, session: Session) -> None:
await self._persist_session(session=session)

async def _persist_session(self, *, session: Session) -> None:
data = {"session": session.dict(), "expires_at": session.expires_at}
data = {"session": session.model_dump(), "expires_at": session.expires_at}
await self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

async def _remove_session(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions gotrue/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def list_users(self) -> List[User]:
return await self._request(
"GET",
"admin/users",
xform=lambda data: [User.parse_obj(user) for user in data["users"]]
xform=lambda data: [User.model_validate(user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +161,7 @@ async def _list_factors(
return await self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.parse_obj,
xform=AuthMFAAdminListFactorsResponse.model_validate,
)

async def _delete_factor(
Expand All @@ -171,5 +171,5 @@ async def _delete_factor(
return await self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.parse_obj,
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
)
2 changes: 1 addition & 1 deletion gotrue/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def _request(
url,
headers=headers,
params=query,
json=body.dict() if isinstance(body, BaseModel) else body,
json=body.model_dump() if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
14 changes: 7 additions & 7 deletions gotrue/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.parse_obj,
xform=AuthMFAEnrollResponse.model_validate,
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -545,7 +545,7 @@ async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeRespon
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.parse_obj,
xform=AuthMFAChallengeResponse.model_validate,
)

async def _challenge_and_verify(
Expand Down Expand Up @@ -574,9 +574,9 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.parse_obj,
xform=AuthMFAVerifyResponse.model_validate,
)
session = Session.parse_obj(response.dict())
session = Session.model_validate(response.model_dump())
await self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -589,7 +589,7 @@ async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.parse_obj,
xform=AuthMFAUnenrollResponse.model_validate,
)

async def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -751,7 +751,7 @@ async def _save_session(self, session: Session) -> None:
value = (expire_in - refresh_duration_before_expires) * 1000
await self._start_auto_refresh_token(value)
if self._persist_session and session.expires_at:
await self._storage.set_item(self._storage_key, session.json())
await self._storage.set_item(self._storage_key, session.model_dump_json())

async def _start_auto_refresh_token(self, value: float) -> None:
if self._refresh_token_timer:
Expand Down Expand Up @@ -808,7 +808,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.parse_obj(data)
return Session.model_validate(data)
except Exception:
return None

Expand Down
4 changes: 2 additions & 2 deletions gotrue/_sync/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Dict, List, Optional, Union

from pydantic import parse_obj_as
from pydantic import TypeAdapter

from ..exceptions import APIError
from ..helpers import check_response, encode_uri_component
Expand Down Expand Up @@ -94,7 +94,7 @@ def list_users(self) -> List[User]:
raise APIError("No users found in response", 400)
if not isinstance(users, list):
raise APIError("Expected a list of users", 400)
return parse_obj_as(List[User], users)
return TypeAdapter(List[User]).validate_python(users)

def sign_up_with_email(
self,
Expand Down
4 changes: 2 additions & 2 deletions gotrue/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _recover_common(self) -> Optional[Tuple[Session, int, int]]:
and session_raw
and isinstance(session_raw, dict)
):
session = Session.parse_obj(session_raw)
session = Session.model_validate(session_raw)
expires_at = int(expires_at_raw)
time_now = round(time())
return session, expires_at, time_now
Expand Down Expand Up @@ -620,7 +620,7 @@ def _save_session(self, *, session: Session) -> None:
self._persist_session(session=session)

def _persist_session(self, *, session: Session) -> None:
data = {"session": session.dict(), "expires_at": session.expires_at}
data = {"session": session.model_dump(), "expires_at": session.expires_at}
self.local_storage.set_item(STORAGE_KEY, dumps(data, default=str))

def _remove_session(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions gotrue/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def list_users(self) -> List[User]:
return self._request(
"GET",
"admin/users",
xform=lambda data: [User.parse_obj(user) for user in data["users"]]
xform=lambda data: [User.model_validate(user) for user in data["users"]]
if "users" in data
else [],
)
Expand Down Expand Up @@ -161,7 +161,7 @@ def _list_factors(
return self._request(
"GET",
f"admin/users/{params.get('user_id')}/factors",
xform=AuthMFAAdminListFactorsResponse.parse_obj,
xform=AuthMFAAdminListFactorsResponse.model_validate,
)

def _delete_factor(
Expand All @@ -171,5 +171,5 @@ def _delete_factor(
return self._request(
"DELETE",
f"admin/users/{params.get('user_id')}/factors/{params.get('factor_id')}",
xform=AuthMFAAdminDeleteFactorResponse.parse_obj,
xform=AuthMFAAdminDeleteFactorResponse.model_validate,
)
2 changes: 1 addition & 1 deletion gotrue/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _request(
url,
headers=headers,
params=query,
json=body.dict() if isinstance(body, BaseModel) else body,
json=body.model_dump() if isinstance(body, BaseModel) else body,
)
response.raise_for_status()
result = response if no_resolve_json else response.json()
Expand Down
12 changes: 6 additions & 6 deletions gotrue/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse:
"factors",
body=params,
jwt=session.access_token,
xform=AuthMFAEnrollResponse.parse_obj,
xform=AuthMFAEnrollResponse.model_validate,
)
if response.totp.qr_code:
response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}"
Expand All @@ -543,7 +543,7 @@ def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse:
"POST",
f"factors/{params.get('factor_id')}/challenge",
jwt=session.access_token,
xform=AuthMFAChallengeResponse.parse_obj,
xform=AuthMFAChallengeResponse.model_validate,
)

def _challenge_and_verify(
Expand Down Expand Up @@ -572,9 +572,9 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse:
f"factors/{params.get('factor_id')}/verify",
body=params,
jwt=session.access_token,
xform=AuthMFAVerifyResponse.parse_obj,
xform=AuthMFAVerifyResponse.model_validate,
)
session = Session.parse_obj(response.dict())
session = Session.model_validate(response.model_dump())
self._save_session(session)
self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session)
return response
Expand All @@ -587,7 +587,7 @@ def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse:
"DELETE",
f"factors/{params.get('factor_id')}",
jwt=session.access_token,
xform=AuthMFAUnenrollResponse.parse_obj,
xform=AuthMFAUnenrollResponse.model_validate,
)

def _list_factors(self) -> AuthMFAListFactorsResponse:
Expand Down Expand Up @@ -806,7 +806,7 @@ def _get_valid_session(
except ValueError:
return None
try:
return Session.parse_obj(data)
return Session.model_validate(data)
except Exception:
return None

Expand Down
10 changes: 6 additions & 4 deletions gotrue/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def parse_auth_response(data: Any) -> AuthResponse:
and data["refresh_token"]
and data["expires_in"]
):
session = Session.parse_obj(data)
session = Session.model_validate(data)
user_data = data.get("user", data)
user = User.parse_obj(user_data) if user_data else None
user = User.model_validate(user_data) if user_data else None
return AuthResponse(session=session, user=user)


Expand All @@ -41,14 +41,16 @@ def parse_link_response(data: Any) -> GenerateLinkResponse:
redirect_to=data.get("redirect_to"),
verification_type=data.get("verification_type"),
)
user = User.parse_obj({k: v for k, v in data.items() if k not in properties.dict()})
user = User.model_validate(
{k: v for k, v in data.items() if k not in properties.model_dump()}
)
return GenerateLinkResponse(properties=properties, user=user)


def parse_user_response(data: Any) -> UserResponse:
if "user" not in data:
data = {"user": data}
return UserResponse.parse_obj(data)
return UserResponse.model_validate(data)


def get_error_message(error: Any) -> str:
Expand Down
42 changes: 21 additions & 21 deletions gotrue/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from time import time
from typing import Any, Callable, Dict, List, Union

from pydantic import BaseModel, root_validator
from pydantic import BaseModel, model_validator
from typing_extensions import Literal, NotRequired, TypedDict

Provider = Literal[
Expand Down Expand Up @@ -106,7 +106,7 @@ class Session(BaseModel):
token_type: str
user: User

@root_validator
@model_validator(mode="before")
def validator(cls, values: dict) -> dict:
expires_in = values.get("expires_in")
if expires_in and not values.get("expires_at"):
Expand Down Expand Up @@ -615,22 +615,22 @@ class DecodedJWTDict(TypedDict):
amr: NotRequired[Union[List[AMREntry], None]]


AMREntry.update_forward_refs()
AuthResponse.update_forward_refs()
OAuthResponse.update_forward_refs()
UserResponse.update_forward_refs()
Session.update_forward_refs()
UserIdentity.update_forward_refs()
Factor.update_forward_refs()
User.update_forward_refs()
Subscription.update_forward_refs()
AuthMFAVerifyResponse.update_forward_refs()
AuthMFAEnrollResponseTotp.update_forward_refs()
AuthMFAEnrollResponse.update_forward_refs()
AuthMFAUnenrollResponse.update_forward_refs()
AuthMFAChallengeResponse.update_forward_refs()
AuthMFAListFactorsResponse.update_forward_refs()
AuthMFAGetAuthenticatorAssuranceLevelResponse.update_forward_refs()
AuthMFAAdminDeleteFactorResponse.update_forward_refs()
AuthMFAAdminListFactorsResponse.update_forward_refs()
GenerateLinkProperties.update_forward_refs()
AMREntry.model_rebuild()
AuthResponse.model_rebuild()
OAuthResponse.model_rebuild()
UserResponse.model_rebuild()
Session.model_rebuild()
UserIdentity.model_rebuild()
Factor.model_rebuild()
User.model_rebuild()
Subscription.model_rebuild()
AuthMFAVerifyResponse.model_rebuild()
AuthMFAEnrollResponseTotp.model_rebuild()
AuthMFAEnrollResponse.model_rebuild()
AuthMFAUnenrollResponse.model_rebuild()
AuthMFAChallengeResponse.model_rebuild()
AuthMFAListFactorsResponse.model_rebuild()
AuthMFAGetAuthenticatorAssuranceLevelResponse.model_rebuild()
AuthMFAAdminDeleteFactorResponse.model_rebuild()
AuthMFAAdminListFactorsResponse.model_rebuild()
GenerateLinkProperties.model_rebuild()
Loading