From 2454694de330f2e986f981397d7cef90393d573e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Mon, 29 Apr 2024 15:11:02 -0700 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20types=20to=20pr?= =?UTF-8?q?operly=20support=20Pydantic=202.7=20(#913)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 4 +++- sqlmodel/main.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 072d2b0f58..72ec8330fd 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -18,11 +18,13 @@ Union, ) -from pydantic import VERSION as PYDANTIC_VERSION +from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo from typing_extensions import get_args, get_origin +# Reassign variable to make it reexported for mypy +PYDANTIC_VERSION = P_VERSION IS_PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 9e8330d69d..a16428b192 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -6,6 +6,7 @@ from enum import Enum from pathlib import Path from typing import ( + TYPE_CHECKING, AbstractSet, Any, Callable, @@ -55,6 +56,7 @@ from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, + PYDANTIC_VERSION, BaseConfig, ModelField, ModelMetaclass, @@ -80,6 +82,12 @@ ) from .sql.sqltypes import GUID, AutoString +if TYPE_CHECKING: + from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass + from pydantic._internal._repr import Representation as Representation + from pydantic_core import PydanticUndefined as Undefined + from pydantic_core import PydanticUndefinedType as UndefinedType + _T = TypeVar("_T") NoArgAnyCallable = Callable[[], Any] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any], None] @@ -764,13 +772,22 @@ def model_dump( mode: Union[Literal["json", "python"], str] = "python", include: IncEx = None, exclude: IncEx = None, + context: Union[Dict[str, Any], None] = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool = True, + warnings: Union[bool, Literal["none", "warn", "error"]] = True, + serialize_as_any: bool = False, ) -> Dict[str, Any]: + if PYDANTIC_VERSION >= "2.7.0": + extra_kwargs: Dict[str, Any] = { + "context": context, + "serialize_as_any": serialize_as_any, + } + else: + extra_kwargs = {} if IS_PYDANTIC_V2: return super().model_dump( mode=mode, @@ -782,6 +799,7 @@ def model_dump( exclude_none=exclude_none, round_trip=round_trip, warnings=warnings, + **extra_kwargs, ) else: return super().dict(