Skip to content

Commit

Permalink
♻️ Refactor types to properly support Pydantic 2.7 (#913)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiangolo authored Apr 29, 2024
1 parent 6151f23 commit 2454694
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
Union,
)

from pydantic import VERSION as PYDANTIC_VERSION
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")


Expand Down
20 changes: 19 additions & 1 deletion sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from enum import Enum
from pathlib import Path
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Expand Down Expand Up @@ -55,6 +56,7 @@

from ._compat import ( # type: ignore[attr-defined]
IS_PYDANTIC_V2,
PYDANTIC_VERSION,
BaseConfig,
ModelField,
ModelMetaclass,
Expand All @@ -80,6 +82,12 @@
)
from .sql.sqltypes import GUID, AutoString

if TYPE_CHECKING:
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
from pydantic._internal._repr import Representation as Representation
from pydantic_core import PydanticUndefined as Undefined
from pydantic_core import PydanticUndefinedType as UndefinedType

_T = TypeVar("_T")
NoArgAnyCallable = Callable[[], Any]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None]
Expand Down Expand Up @@ -764,13 +772,22 @@ def model_dump(
mode: Union[Literal["json", "python"], str] = "python",
include: IncEx = None,
exclude: IncEx = None,
context: Union[Dict[str, Any], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
warnings: Union[bool, Literal["none", "warn", "error"]] = True,
serialize_as_any: bool = False,
) -> Dict[str, Any]:
if PYDANTIC_VERSION >= "2.7.0":
extra_kwargs: Dict[str, Any] = {
"context": context,
"serialize_as_any": serialize_as_any,
}
else:
extra_kwargs = {}
if IS_PYDANTIC_V2:
return super().model_dump(
mode=mode,
Expand All @@ -782,6 +799,7 @@ def model_dump(
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
**extra_kwargs,
)
else:
return super().dict(
Expand Down

0 comments on commit 2454694

Please sign in to comment.