Skip to content

Commit

Permalink
✨ Fully merge Pydantic Field with SQLAlchemy Column constructor;
Browse files Browse the repository at this point in the history
allow passing all `Column` arguments directly to `Field`;
make `default` a keyword-only argument for `Field`
  • Loading branch information
daniil-berg committed Sep 4, 2022
1 parent 75ce455 commit c93b42e
Showing 1 changed file with 226 additions and 92 deletions.
318 changes: 226 additions & 92 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
import warnings
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand Down Expand Up @@ -38,8 +39,10 @@
from sqlalchemy.orm.attributes import set_attribute
from sqlalchemy.orm.decl_api import DeclarativeMeta
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import FetchedValue, MetaData, SchemaItem
from sqlalchemy.sql.sqltypes import LargeBinary, Time
from sqlalchemy.sql.type_api import TypeEngine

from .sql.sqltypes import GUID, AutoString

Expand All @@ -57,35 +60,94 @@ def __dataclass_transform__(


class FieldInfo(PydanticFieldInfo):
def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
primary_key = kwargs.pop("primary_key", False)
nullable = kwargs.pop("nullable", Undefined)
foreign_key = kwargs.pop("foreign_key", Undefined)
unique = kwargs.pop("unique", False)
index = kwargs.pop("index", Undefined)
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
if sa_column is not Undefined:
if sa_column_args is not Undefined:
raise RuntimeError(
"Passing sa_column_args is not supported when "
"also passing a sa_column"
)
if sa_column_kwargs is not Undefined:
raise RuntimeError(
"Passing sa_column_kwargs is not supported when "
"also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
self.foreign_key = foreign_key
self.unique = unique
self.index = index
self.sa_column = sa_column
self.sa_column_args = sa_column_args
self.sa_column_kwargs = sa_column_kwargs

# In addition to the `PydanticFieldInfo` slots, set slots corresponding to parameters for the SQLAlchemy
# [Column](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column),
# along with any custom additions:
__slots__ = (
"name",
"type_",
"args",
"autoincrement",
# `default` omitted because that slot is defined on the base class
"doc",
"key",
"index",
"info",
"nullable",
"onupdate",
"primary_key",
"server_default",
"server_onupdate",
"quote",
"unique",
"system",
"comment",
"foreign_key", # custom parameter for easier foreign key setting
# For backwards compatibility: (!?)
"sa_column",
"sa_column_args",
"sa_column_kwargs",
)

# Defined here for static type checkers:
name: Union[str, UndefinedType]
type_: Union[TypeEngine, UndefinedType] # type: ignore[type-arg]
args: Sequence[SchemaItem]
autoincrement: Union[bool, str]
doc: Optional[str]
key: Union[str, UndefinedType]
index: Optional[bool]
info: Union[Dict[str, Any], UndefinedType]
nullable: Union[bool, UndefinedType]
onupdate: Any
primary_key: bool
server_default: Union[FetchedValue, str, TextClause, None]
server_onupdate: Optional[FetchedValue]
quote: Union[bool, None, UndefinedType]
unique: Optional[bool]
system: bool
comment: Optional[str]

foreign_key: Optional[str]

sa_column: Union[Column, UndefinedType] # type: ignore[type-arg]
sa_column_args: Sequence[Any]
sa_column_kwargs: Mapping[str, Any]

def __init__(self, **kwargs: Any) -> None:
# Split off all keyword-arguments corresponding to our new additional attributes:
new_kwargs = {param: kwargs.pop(param, Undefined) for param in self.__slots__}
# Pass the rest of the keyword-arguments to the Pydantic `FieldInfo.__init__`:
super().__init__(**kwargs)
# Set the other keyword-arguments as instance attributes:
for param, value in new_kwargs.items():
setattr(self, param, value)

def get_defined_column_kwargs(self) -> Dict[str, Any]:
"""
Returns a dictionary of keyword arguments for the SQLAlchemy `Column.__init__` method
derived from the corresponding attributes of the `FieldInfo` instance,
omitting all those that have been left undefined.
"""
special = {
"args",
"foreign_key",
"sa_column",
"sa_column_args",
"sa_column_kwargs",
}
kwargs = {}
for key in self.__slots__:
if key in special:
continue
value = getattr(self, key, Undefined)
if value is not Undefined:
kwargs[key] = value
default = get_field_info_default(self)
if default is not Undefined:
kwargs["default"] = default
return kwargs


class RelationshipInfo(Representation):
Expand Down Expand Up @@ -117,8 +179,9 @@ def __init__(


def Field(
default: Any = Undefined,
*,
*args: SchemaItem, # positional arguments for SQLAlchemy `Column.__init__`
default: Any = Undefined, # meaningful for both Pydantic and SQLAlchemy
# The following are specific to Pydantic:
default_factory: Optional[NoArgAnyCallable] = None,
alias: Optional[str] = None,
title: Optional[str] = None,
Expand All @@ -141,19 +204,78 @@ def Field(
max_length: Optional[int] = None,
allow_mutation: bool = True,
regex: Optional[str] = None,
# The following are specific to SQLAlchemy:
name: Optional[str] = None,
type_: Union[TypeEngine, UndefinedType] = Undefined, # type: ignore[type-arg]
autoincrement: Union[bool, str] = "auto",
doc: Optional[str] = None,
key: Union[str, UndefinedType] = Undefined, # `Column` default is `name`
index: Optional[bool] = None,
info: Union[Dict[str, Any], UndefinedType] = Undefined, # `Column` default is `{}`
nullable: Union[
bool, UndefinedType
] = Undefined, # `Column` default depends on `primary_key`
onupdate: Any = None,
primary_key: bool = False,
foreign_key: Optional[Any] = None,
unique: bool = False,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
server_default: Union[FetchedValue, str, TextClause, None] = None,
server_onupdate: Optional[FetchedValue] = None,
quote: Union[
bool, None, UndefinedType
] = Undefined, # `Column` default not (fully) defined
unique: Optional[bool] = None,
system: bool = False,
comment: Optional[str] = None,
foreign_key: Optional[str] = None,
# For backwards compatibility: (!?)
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore[type-arg]
sa_column_args: Sequence[Any] = (),
sa_column_kwargs: Optional[Mapping[str, Any]] = None,
# Extra:
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
) -> FieldInfo:
"""
Constructor for explicitly defining the attributes of a model field.
The resulting field information is used both for Pydantic model validation **and** for SQLAlchemy column definition.
The following parameters are passed to initialize the Pydantic `FieldInfo`
(see [`Field` docs](https://pydantic-docs.helpmanual.io/usage/schema/#field-customization)):
`default`, `default_factory`, `alias`, `title`, `description`, `exclude`, `include`, `const`, `gt`, `ge`,
`lt`, `le`, `multiple_of`, `min_items`, `max_items`, `min_length`, `max_length`, `allow_mutation`, `regex`.
These parameters are passed to initialize the SQLAlchemy
[`Column`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column):
`*args`, `name`, `type_`, `autoincrement`, `doc`, `key`, `index`, `info`, `nullable`, `onupdate`, `primary_key`,
`server_default`, `server_onupdate`, `quote`, `unique`, `system`, `comment`.
If provided, the `default_factory` argument is passed as `default` to the `Column` constructor;
otherwise, if the `default` argument is provided, it is passed to the `Column` constructor.
Note:
The SQLAlchemy `Column` default for `type_` is actually `None`, but it makes more sense to leave it undefined,
unless an argument is passed explicitly. If someone explicitly wants to pass `None` to set the `NullType` for
whatever reason, they will be able to do so.
(see [`type_`](https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.Column.params.type_))
"""
current_schema_extra = schema_extra or {}
# For backwards compatibility: (!?)
if sa_column is not Undefined:
warnings.warn(
"Specifying `sa_column` overrides all other column arguments",
DeprecationWarning,
)
if sa_column_args != ():
warnings.warn(
"Instead of `sa_column_args` use positional arguments",
DeprecationWarning,
)
if sa_column_kwargs is not None:
warnings.warn(
"`sa_column_kwargs` takes precedence over other keyword-arguments",
DeprecationWarning,
)
field_info = FieldInfo(
default,
default=default,
default_factory=default_factory,
alias=alias,
title=title,
Expand All @@ -172,14 +294,27 @@ def Field(
max_length=max_length,
allow_mutation=allow_mutation,
regex=regex,
name=name,
type_=type_,
args=args,
autoincrement=autoincrement,
doc=doc,
key=key,
index=index,
info=info,
nullable=nullable,
onupdate=onupdate,
primary_key=primary_key,
foreign_key=foreign_key,
server_default=server_default,
server_onupdate=server_onupdate,
quote=quote,
unique=unique,
nullable=nullable,
index=index,
system=system,
comment=comment,
foreign_key=foreign_key,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
sa_column_kwargs=sa_column_kwargs or {},
**current_schema_extra,
)
field_info._validate()
Expand Down Expand Up @@ -414,47 +549,48 @@ def get_sqlachemy_type(field: ModelField) -> Any:
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")


def get_column_from_field(field: ModelField) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
sa_type = get_sqlachemy_type(field)
primary_key = getattr(field.field_info, "primary_key", False)
index = getattr(field.field_info, "index", Undefined)
if index is Undefined:
index = False
nullable = not primary_key and _is_field_noneable(field)
# Override derived nullability if the nullable property is set explicitly
# on the field
if hasattr(field.field_info, "nullable"):
field_nullable = getattr(field.field_info, "nullable")
if field_nullable != Undefined:
nullable = field_nullable
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
unique = getattr(field.field_info, "unique", False)
if foreign_key:
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
"index": index,
"unique": unique,
}
sa_default = Undefined
if field.field_info.default_factory:
sa_default = field.field_info.default_factory
elif field.field_info.default is not Undefined:
sa_default = field.field_info.default
if sa_default is not Undefined:
kwargs["default"] = sa_default
sa_column_args = getattr(field.field_info, "sa_column_args", Undefined)
if sa_column_args is not Undefined:
args.extend(list(cast(Sequence[Any], sa_column_args)))
sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined)
if sa_column_kwargs is not Undefined:
kwargs.update(cast(Dict[Any, Any], sa_column_kwargs))
return Column(sa_type, *args, **kwargs) # type: ignore
def get_field_info_default(info: PydanticFieldInfo) -> Any:
"""Returns the `default_factory` if set, otherwise the `default` value."""
return info.default_factory if info.default_factory is not None else info.default


def get_column_from_pydantic_field(field: ModelField) -> Column: # type: ignore[type-arg]
"""Returns an SQLAlchemy `Column` instance derived from a regular Pydantic `ModelField`."""
kwargs = {"type_": get_sqlachemy_type(field), "nullable": _is_field_noneable(field)}
default = get_field_info_default(field.field_info)
if default is not Undefined:
kwargs["default"] = default
return Column(**kwargs)


def get_column_from_field(field: ModelField) -> Column: # type: ignore[type-arg]
"""Returns an SQLAlchemy `Column` instance derived from an SQLModel field."""
if not isinstance(field.field_info, FieldInfo): # must be regular `PydanticFieldInfo`
return get_column_from_pydantic_field(field)
# We are dealing with the customized `FieldInfo` object:
field_info: FieldInfo = field.field_info
# The `sa_column` argument trumps everything: (for backwards compatibility)
if isinstance(field_info.sa_column, Column):
return field_info.sa_column
args: List[SchemaItem] = []
kwargs = field_info.get_defined_column_kwargs()
# Only if no column type was explicitly defined, do we derive it here:
kwargs.setdefault("type_", get_sqlachemy_type(field))
# Only if nullability was not defined, do we infer it here:
kwargs.setdefault(
"nullable", not kwargs.get("primary_key", False) and _is_field_noneable(field)
)
# If a foreign key reference was explicitly named, construct the schema item here,
# and make it the first positional argument for the `Column`:
if field_info.foreign_key:
args.append(ForeignKey(field_info.foreign_key))
# All other positional column arguments are appended:
args.extend(field_info.args)
# Append `sa_column_args`: (for backwards compatibility)
args.extend(field_info.sa_column_args)
# Finally, let the `sa_column_kwargs` take precedence: (for backwards compatibility)
kwargs.update(field_info.sa_column_kwargs)
return Column(*args, **kwargs)


class_registry = weakref.WeakValueDictionary() # type: ignore
Expand Down Expand Up @@ -647,9 +783,7 @@ def __tablename__(cls) -> str:


def _is_field_noneable(field: ModelField) -> bool:
if not field.required:
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
return field.allow_none and (
field.shape != SHAPE_SINGLETON or not field.sub_fields
)
return False
if field.required:
return False
# Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947)
return field.allow_none and (field.shape != SHAPE_SINGLETON or not field.sub_fields)

0 comments on commit c93b42e

Please sign in to comment.