From 968f4959b01d19aeaae7edff91ab5b7120984fe4 Mon Sep 17 00:00:00 2001 From: honglei Date: Tue, 17 Oct 2023 15:50:01 +0800 Subject: [PATCH] add old codes for sqlalchemy v1/pydantic v1 --- sqlmodel/main.py.bak | 698 +++++++++++++++++++++++++++ sqlmodel/v1/__init__.py | 139 ++++++ sqlmodel/v1/default.py | 32 ++ sqlmodel/v1/engine/__init__.py | 0 sqlmodel/v1/engine/create.py | 139 ++++++ sqlmodel/v1/engine/result.py | 79 +++ sqlmodel/v1/ext/__init__.py | 0 sqlmodel/v1/ext/asyncio/__init__.py | 0 sqlmodel/v1/ext/asyncio/session.py | 62 +++ sqlmodel/v1/main.py | 655 +++++++++++++++++++++++++ sqlmodel/v1/orm/__init__.py | 0 sqlmodel/v1/orm/session.py | 141 ++++++ sqlmodel/v1/pool/__init__.py | 1 + sqlmodel/v1/py.typed | 0 sqlmodel/v1/sql/__init__.py | 0 sqlmodel/v1/sql/base.py | 9 + sqlmodel/v1/sql/expression.py | 458 ++++++++++++++++++ sqlmodel/v1/sql/expression.py.jinja2 | 118 +++++ sqlmodel/v1/sql/sqltypes.py | 60 +++ 19 files changed, 2591 insertions(+) create mode 100644 sqlmodel/main.py.bak create mode 100644 sqlmodel/v1/__init__.py create mode 100644 sqlmodel/v1/default.py create mode 100644 sqlmodel/v1/engine/__init__.py create mode 100644 sqlmodel/v1/engine/create.py create mode 100644 sqlmodel/v1/engine/result.py create mode 100644 sqlmodel/v1/ext/__init__.py create mode 100644 sqlmodel/v1/ext/asyncio/__init__.py create mode 100644 sqlmodel/v1/ext/asyncio/session.py create mode 100644 sqlmodel/v1/main.py create mode 100644 sqlmodel/v1/orm/__init__.py create mode 100644 sqlmodel/v1/orm/session.py create mode 100644 sqlmodel/v1/pool/__init__.py create mode 100644 sqlmodel/v1/py.typed create mode 100644 sqlmodel/v1/sql/__init__.py create mode 100644 sqlmodel/v1/sql/base.py create mode 100644 sqlmodel/v1/sql/expression.py create mode 100644 sqlmodel/v1/sql/expression.py.jinja2 create mode 100644 sqlmodel/v1/sql/sqltypes.py diff --git a/sqlmodel/main.py.bak b/sqlmodel/main.py.bak new file mode 100644 index 0000000000..515ef8c323 --- /dev/null +++ b/sqlmodel/main.py.bak @@ -0,0 +1,698 @@ +from __future__ import annotations + +import ipaddress +import sys +import types +import uuid +import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path +from typing import ( + AbstractSet, + Any, + Callable, + ClassVar, + Dict, + ForwardRef, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +import pydantic +from pydantic import BaseModel, EmailStr, NameEmail, ImportString +from pydantic._internal._fields import PydanticGeneralMetadata +from pydantic._internal._model_construction import ModelMetaclass +from pydantic._internal._repr import Representation +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic_core import PydanticUndefined, PydanticUndefinedType +from sqlalchemy import Boolean, Column, Date, DateTime +from sqlalchemy import Enum as sa_Enum +from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect +from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship +from sqlalchemy.orm.attributes import set_attribute +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.orm.instrumentation import is_instrumented +from sqlalchemy.orm.properties import MappedColumn +from sqlalchemy.sql import false, true +from sqlalchemy.sql.schema import DefaultClause, MetaData +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString +from .typing import 
SQLModelConfig + +if sys.version_info >= (3, 8): + from typing import get_args, get_origin +else: + from typing_extensions import get_args, get_origin + +from typing_extensions import Annotated, _AnnotatedAlias + +_T = TypeVar("_T") +NoArgAnyCallable = Callable[[], Any] +NoneType = type(None) + + +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), +) -> Callable[[_T], _T]: + return lambda a: a + + +class FieldInfo(PydanticFieldInfo): + nullable: Union[bool, PydanticUndefinedType] + + def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None: + primary_key = kwargs.pop("primary_key", False) + nullable = kwargs.pop("nullable", PydanticUndefined) + foreign_key = kwargs.pop("foreign_key", PydanticUndefined) + unique = kwargs.pop("unique", False) + index = kwargs.pop("index", PydanticUndefined) + sa_column = kwargs.pop("sa_column", PydanticUndefined) + sa_column_args = kwargs.pop("sa_column_args", PydanticUndefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", PydanticUndefined) + if sa_column is not PydanticUndefined: + if sa_column_args is not PydanticUndefined: + raise RuntimeError( + "Passing sa_column_args is not supported when " + "also passing a sa_column" + ) + if sa_column_kwargs is not PydanticUndefined: + raise RuntimeError( + "Passing sa_column_kwargs is not supported when " + "also passing a sa_column" + ) + super().__init__(default=default, **kwargs) + self.primary_key = primary_key + self.nullable = nullable + self.foreign_key = foreign_key + self.unique = unique + self.index = index + self.sa_column = sa_column + self.sa_column_args = sa_column_args + self.sa_column_kwargs = sa_column_kwargs + + +class RelationshipInfo(Representation): + def __init__( + self, + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, + ) -> None: + if sa_relationship is not None: + if sa_relationship_args is not None: + raise RuntimeError( + "Passing sa_relationship_args is not supported when " + "also passing a sa_relationship" + ) + if sa_relationship_kwargs is not None: + raise RuntimeError( + "Passing sa_relationship_kwargs is not supported when " + "also passing a sa_relationship" + ) + self.back_populates = back_populates + self.link_model = link_model + self.sa_relationship = sa_relationship + self.sa_relationship_args = sa_relationship_args + self.sa_relationship_kwargs = sa_relationship_kwargs + + +def Field( + default: Any = PydanticUndefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + primary_key: bool = False, + 
foreign_key: Optional[Any] = None, + unique: bool = False, + nullable: Union[bool, PydanticUndefinedType] = PydanticUndefined, + index: Union[bool, PydanticUndefinedType] = PydanticUndefined, + sa_column: Union[Column, PydanticUndefinedType, Callable[[], Column]] = PydanticUndefined, # type: ignore + sa_column_args: Union[Sequence[Any], PydanticUndefinedType] = PydanticUndefined, + sa_column_kwargs: Union[ + Mapping[str, Any], PydanticUndefinedType + ] = PydanticUndefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + current_schema_extra = schema_extra or {} + if default is PydanticUndefined: + if isinstance(sa_column, types.FunctionType): # lambda + sa_column_ = sa_column() + else: + sa_column_ = sa_column + + # server_default -> default + if isinstance(sa_column_, Column) and isinstance( + sa_column_.server_default, DefaultClause + ): + default_value = sa_column_.server_default.arg + if issubclass(type(sa_column_.type), Integer) and isinstance( + default_value, str + ): + default = int(default_value) + elif issubclass(type(sa_column_.type), Boolean): + if default_value is false(): + default = False + elif default_value is true(): + default = True + elif isinstance(default_value, str): + if default_value == "1": + default = True + elif default_value == "0": + default = False + + field_info = FieldInfo( + default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + primary_key=primary_key, + foreign_key=foreign_key, + unique=unique, + nullable=nullable, + index=index, + sa_column=sa_column, + sa_column_args=sa_column_args, + sa_column_kwargs=sa_column_kwargs, + **current_schema_extra, + ) + return field_info + + +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty[Any]] = None, + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + relationship_info = RelationshipInfo( + back_populates=back_populates, + link_model=link_model, + sa_relationship=sa_relationship, + sa_relationship_args=sa_relationship_args, + sa_relationship_kwargs=sa_relationship_kwargs, + ) + return relationship_info + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): + __sqlmodel_relationships__: Dict[str, RelationshipInfo] + model_config: SQLModelConfig + model_fields: Dict[str, FieldInfo] + + # Replicate SQLAlchemy + def __setattr__(cls, name: str, value: Any) -> None: + if cls.model_config.get("table", False): + DeclarativeMeta.__setattr__(cls, name, value) + else: + super().__setattr__(name, value) + + def __delattr__(cls, name: str) -> None: + if cls.model_config.get("table", False): + DeclarativeMeta.__delattr__(cls, name) + else: + super().__delattr__(name) + + # From Pydantic + def __new__( + cls, + name: str, + bases: Tuple[Type[Any], ...], + class_dict: Dict[str, Any], + **kwargs: Any, + ) -> Any: + relationships: Dict[str, RelationshipInfo] = {} + dict_for_pydantic = {} + original_annotations = class_dict.get("__annotations__", {}) + pydantic_annotations = {} + relationship_annotations = {} + for k, v in class_dict.items(): + if 
isinstance(v, RelationshipInfo): + relationships[k] = v + else: + dict_for_pydantic[k] = v + for k, v in original_annotations.items(): + if k in relationships: + relationship_annotations[k] = v + else: + pydantic_annotations[k] = v + dict_used = { + **dict_for_pydantic, + "__weakref__": None, + "__sqlmodel_relationships__": relationships, + "__annotations__": pydantic_annotations, + } + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: Set[str] = { + key + for key in dir(SQLModelConfig) + if not ( + key.startswith("__") and key.endswith("__") + ) # skip dunder methods and attributes + } + pydantic_kwargs = kwargs.copy() + config_kwargs = { + key: pydantic_kwargs.pop(key) + for key in pydantic_kwargs.keys() & allowed_config_kwargs + } + config_table = getattr( + class_dict.get("Config", object()), "table", False + ) or kwargs.get("table", False) + # If we have a table, we need to have defaults for all fields + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything + if config_table is True: + for key in pydantic_annotations.keys(): + value = dict_used.get(key, PydanticUndefined) + if value is PydanticUndefined: + dict_used[key] = None + elif isinstance(value, FieldInfo): + if ( + value.default in (PydanticUndefined, Ellipsis) + ) and value.default_factory is None: + value.original_default = ( + value.default + ) # So we can check for nullable + value.default = None + + new_cls: Type["SQLModelMetaclass"] = super().__new__( + cls, name, bases, dict_used, **config_kwargs + ) + new_cls.__annotations__ = { + **relationship_annotations, + **pydantic_annotations, + **new_cls.__annotations__, + } + + def get_config(name: str) -> Any: + config_class_value = new_cls.model_config.get(name, PydanticUndefined) + if config_class_value is not PydanticUndefined: + return config_class_value + kwarg_value = kwargs.get(name, PydanticUndefined) + if kwarg_value is not PydanticUndefined: + return kwarg_value + return PydanticUndefined + + config_table = get_config("table") + if config_table is True: + # If it was passed by kwargs, ensure it's also set in config + new_cls.model_config["table"] = config_table + for k, v in new_cls.model_fields.items(): + col = get_column_from_field(v) + setattr(new_cls, k, col) + # Set a config flag to tell FastAPI that this should be read with a field + # in orm_mode instead of preemptively converting it to a dict. + # This could be done by reading new_cls.model_config['table'] in FastAPI, but + # that's very specific about SQLModel, so let's have another config that + # other future tools based on Pydantic can use. 
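+            # A minimal sketch, assuming a hypothetical Hero model (not part of
+            # the original code): this table branch is what runs for a class
+            # declared roughly as:
+            #
+            #     class Hero(SQLModel, table=True):
+            #         id: Optional[int] = Field(default=None, primary_key=True)
+            #         name: str
+            #
+            # Each Pydantic field was turned into a SQLAlchemy Column by the
+            # loop above, and the flag below lets tools read attributes from
+            # the instance instead of first converting it to a dict.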
+            new_cls.model_config["read_from_attributes"] = True
+
+        config_registry = get_config("registry")
+        if config_registry is not PydanticUndefined:
+            config_registry = cast(registry, config_registry)
+            # If it was passed by kwargs, ensure it's also set in config
+            new_cls.model_config["registry"] = config_registry
+            setattr(new_cls, "_sa_registry", config_registry)
+            setattr(new_cls, "metadata", config_registry.metadata)
+            setattr(new_cls, "__abstract__", True)
+        return new_cls
+
+    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
+    def __init__(
+        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
+    ) -> None:
+        # Only one of the base classes (or the current one) should be a table model;
+        # this allows FastAPI to clone a SQLModel for the response_model without
+        # trying to create a new SQLAlchemy model, for a new table, with the same
+        # name, which would trigger an error
+        base_is_table = False
+        for base in bases:
+            config = getattr(base, "model_config")
+            if config and getattr(config, "table", False):
+                base_is_table = True
+                break
+        if cls.model_config.get("table", False) and not base_is_table:
+            dict_used = dict_.copy()
+            for field_name, field_value in cls.model_fields.items():
+                dict_used[field_name] = get_column_from_field(field_value)
+            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
+                if rel_info.sa_relationship:
+                    # A SQLAlchemy relationship was declared explicitly; it takes
+                    # precedence over anything else, so use it and continue with
+                    # the next attribute
+                    dict_used[rel_name] = rel_info.sa_relationship
+                    continue
+                ann = cls.__annotations__[rel_name]
+                relationship_to = get_origin(ann)
+                # Direct relationships (e.g. 'Team' or Team) have None as an origin
+                if relationship_to is None:
+                    relationship_to = ann
+                # If Union (e.g. 
Optional), get the real field + elif relationship_to is Union: + relationship_to = get_args(ann)[0] + # If a list, then also get the real field + elif relationship_to is list: + relationship_to = get_args(ann)[0] + if isinstance(relationship_to, ForwardRef): + relationship_to = relationship_to.__forward_arg__ + rel_kwargs: Dict[str, Any] = {} + if rel_info.back_populates: + rel_kwargs["back_populates"] = rel_info.back_populates + if rel_info.link_model: + ins = inspect(rel_info.link_model) + local_table = getattr(ins, "local_table") + if local_table is None: + raise RuntimeError( + "Couldn't find the secondary table for " + f"model {rel_info.link_model}" + ) + rel_kwargs["secondary"] = local_table + rel_args: List[Any] = [] + if rel_info.sa_relationship_args: + rel_args.extend(rel_info.sa_relationship_args) + if rel_info.sa_relationship_kwargs: + rel_kwargs.update(rel_info.sa_relationship_kwargs) + rel_value: RelationshipProperty[Any] = relationship( + relationship_to, *rel_args, **rel_kwargs + ) + dict_used[rel_name] = rel_value + setattr(cls, rel_name, rel_value) # Fix #315 + DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw) + else: + ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + + +def _is_optional_or_union(type_: Optional[type]) -> bool: + if sys.version_info >= (3, 10): + return get_origin(type_) in (types.UnionType, Union) + else: + return get_origin(type_) is Union + + +def get_sqlalchemy_type(field: FieldInfo) -> Any: + type_: Optional[type] | _AnnotatedAlias = field.annotation + + # Resolve Optional/Union fields + + if type_ is not None and _is_optional_or_union(type_): + bases = get_args(type_) + if len(bases) > 2: + raise RuntimeError( + "Cannot have a (non-optional) union as a SQL alchemy field" + ) + type_ = bases[0] + # Resolve Annoted fields, + # like typing.Annotated[pydantic_core._pydantic_core.Url, + # UrlConstraints(max_length=512, + # allowed_schemes=['smb', 'ftp', 'file']) ] + if type_ is pydantic.AnyUrl: + if field.metadata: + meta = field.metadata[0] + return AutoString(length=meta.max_length) + else: + return AutoString + + org_type = get_origin(type_) + if org_type is Annotated: + type2 = get_args(type_)[0] + if type2 is pydantic.AnyUrl: + meta = get_args(type_)[1] + return AutoString(length=meta.max_length) + elif org_type is pydantic.AnyUrl and type(type_) is _AnnotatedAlias: + return AutoString(type_.__metadata__[0].max_length) + + # The 3rd is PydanticGeneralMetadata + metadata = _get_field_metadata(field) + if type_ is None: + raise ValueError("Missing field type") + if issubclass(type_, str) or type_ in (EmailStr, NameEmail, ImportString): + max_length = getattr(metadata, "max_length", None) + if max_length: + return AutoString(length=max_length) + return AutoString + if issubclass(type_, float): + return Float + if issubclass(type_, bool): + return Boolean + if issubclass(type_, int): + return Integer + if issubclass(type_, datetime): + return DateTime + if issubclass(type_, date): + return Date + if issubclass(type_, timedelta): + return Interval + if issubclass(type_, time): + return Time + if issubclass(type_, Enum): + return sa_Enum(type_) + if issubclass(type_, bytes): + return LargeBinary + if issubclass(type_, Decimal): + return Numeric( + precision=getattr(metadata, "max_digits", None), + scale=getattr(metadata, "decimal_places", None), + ) + if issubclass(type_, ipaddress.IPv4Address): + return AutoString + if issubclass(type_, ipaddress.IPv4Network): + return AutoString + if issubclass(type_, ipaddress.IPv6Address): + 
return AutoString + if issubclass(type_, ipaddress.IPv6Network): + return AutoString + if issubclass(type_, Path): + return AutoString + if issubclass(type_, uuid.UUID): + return GUID + raise ValueError(f"The field {field.title} has no matching SQLAlchemy type") + + +def get_column_from_field(field: FieldInfo) -> Column: # type: ignore + """ + sa_column > field attributes > annotation info + """ + sa_column = getattr(field, "sa_column", PydanticUndefined) + if isinstance(sa_column, Column): + return sa_column + if isinstance(sa_column, MappedColumn): + return sa_column.column + if isinstance(sa_column, types.FunctionType): + col = sa_column() + assert isinstance(col, Column) + return col + sa_type = get_sqlalchemy_type(field) + primary_key = getattr(field, "primary_key", False) + index = getattr(field, "index", PydanticUndefined) + if index is PydanticUndefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + if hasattr(field, "nullable"): + field_nullable = getattr(field, "nullable") + if field_nullable != PydanticUndefined: + nullable = field_nullable + args = [] + foreign_key = getattr(field, "foreign_key", None) + unique = getattr(field, "unique", False) + if foreign_key: + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default: Union[PydanticUndefinedType, Callable[[], Any]] = PydanticUndefined + if field.default_factory: + sa_default = field.default_factory + elif field.default is not PydanticUndefined: + sa_default = field.default + if sa_default is not PydanticUndefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field, "sa_column_args", PydanticUndefined) + if sa_column_args is not PydanticUndefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field, "sa_column_kwargs", PydanticUndefined) + if sa_column_kwargs is not PydanticUndefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore + + +class_registry = weakref.WeakValueDictionary() # type: ignore + +default_registry = registry() +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") + + +class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): + # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values + __slots__ = ("__weakref__",) + __tablename__: ClassVar[Union[str, Callable[..., str]]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]] + __name__: ClassVar[str] + metadata: ClassVar[MetaData] + __allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six + model_config = SQLModelConfig(from_attributes=True) + + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. 
when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__pydantic_fields_set__", set()) + if not hasattr(new_object, "__pydantic_extra__"): + object.__setattr__(new_object, "__pydantic_extra__", None) + if not hasattr(new_object, "__pydantic_private__"): + object.__setattr__(new_object, "__pydantic_private__", None) + return new_object + + def __init__(__pydantic_self__, **data: Any) -> None: + old_dict = __pydantic_self__.__dict__.copy() + super().__init__(**data) + __pydantic_self__.__dict__ = {**old_dict, **__pydantic_self__.__dict__} + non_pydantic_keys = data.keys() - __pydantic_self__.model_fields + for key in non_pydantic_keys: + if key in __pydantic_self__.__sqlmodel_relationships__: + setattr(__pydantic_self__, key, data[key]) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"_sa_instance_state"}: + self.__dict__[name] = value + return + else: + # Set in SQLAlchemy, before Pydantic to trigger events and updates + if self.model_config.get("table", False) and is_instrumented(self, name): # type: ignore + set_attribute(self, name, value) + # Set in Pydantic model to trigger possible validation changes, only for + # non relationship values + if name not in self.__sqlmodel_relationships__: + super(SQLModel, self).__setattr__(name, value) + + def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: + # Don't show SQLAlchemy private attributes + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] + + @declared_attr # type: ignore + def __tablename__(cls) -> str: + return cls.__name__.lower() + + @classmethod + def model_validate( + cls: Type[_TSQLModel], + obj: Any, + *, + strict: Optional[bool] = None, + from_attributes: Optional[bool] = None, + context: Optional[Dict[str, Any]] = None, + ) -> _TSQLModel: + # Somehow model validate doesn't call __init__ so it would remove our init logic + validated = super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + + # remove defaults so they don't get validated + data = {} + for key, value in validated: + field = cls.model_fields.get(key) + + if field is None: + continue + + if ( + hasattr(field, "default") + and field.default is not PydanticUndefined + and value == field.default + ): + continue + + data[key] = value + + return cls(**data) + + +def _is_field_noneable(field: FieldInfo) -> bool: + if hasattr(field, "nullable") and not isinstance( + field.nullable, PydanticUndefinedType + ): + return field.nullable + if not field.is_required(): + default = getattr(field, "original_default", field.default) + if default is PydanticUndefined: + return False + if field.annotation is None or field.annotation is NoneType: + return True + if _is_optional_or_union(field.annotation): + for base in get_args(field.annotation): + if base is NoneType: + return True + + return False + return False + + +def _get_field_metadata(field: FieldInfo) -> object: + for meta in field.metadata: + if isinstance(meta, PydanticGeneralMetadata): + return meta + if isinstance(meta,MaxLen ): + return meta + return object() diff --git a/sqlmodel/v1/__init__.py b/sqlmodel/v1/__init__.py new file mode 100644 index 0000000000..720aa8c929 --- /dev/null +++ b/sqlmodel/v1/__init__.py @@ -0,0 +1,139 @@ +__version__ = "0.0.8" + +# Re-export from SQLAlchemy +from sqlalchemy.engine import create_mock_engine as create_mock_engine +from sqlalchemy.engine import engine_from_config as engine_from_config +from sqlalchemy.inspection import inspect as 
inspect +from sqlalchemy.schema import BLANK_SCHEMA as BLANK_SCHEMA +from sqlalchemy.schema import CheckConstraint as CheckConstraint +from sqlalchemy.schema import Column as Column +from sqlalchemy.schema import ColumnDefault as ColumnDefault +from sqlalchemy.schema import Computed as Computed +from sqlalchemy.schema import Constraint as Constraint +from sqlalchemy.schema import DDL as DDL +from sqlalchemy.schema import DefaultClause as DefaultClause +from sqlalchemy.schema import FetchedValue as FetchedValue +from sqlalchemy.schema import ForeignKey as ForeignKey +from sqlalchemy.schema import ForeignKeyConstraint as ForeignKeyConstraint +from sqlalchemy.schema import Identity as Identity +from sqlalchemy.schema import Index as Index +from sqlalchemy.schema import MetaData as MetaData +from sqlalchemy.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from sqlalchemy.schema import Sequence as Sequence +from sqlalchemy.schema import Table as Table +from sqlalchemy.schema import ThreadLocalMetaData as ThreadLocalMetaData +from sqlalchemy.schema import UniqueConstraint as UniqueConstraint +from sqlalchemy.sql import alias as alias +from sqlalchemy.sql import all_ as all_ +from sqlalchemy.sql import and_ as and_ +from sqlalchemy.sql import any_ as any_ +from sqlalchemy.sql import asc as asc +from sqlalchemy.sql import between as between +from sqlalchemy.sql import bindparam as bindparam +from sqlalchemy.sql import case as case +from sqlalchemy.sql import cast as cast +from sqlalchemy.sql import collate as collate +from sqlalchemy.sql import column as column +from sqlalchemy.sql import delete as delete +from sqlalchemy.sql import desc as desc +from sqlalchemy.sql import distinct as distinct +from sqlalchemy.sql import except_ as except_ +from sqlalchemy.sql import except_all as except_all +from sqlalchemy.sql import exists as exists +from sqlalchemy.sql import extract as extract +from sqlalchemy.sql import false as false +from sqlalchemy.sql import func as func +from sqlalchemy.sql import funcfilter as funcfilter +from sqlalchemy.sql import insert as insert +from sqlalchemy.sql import intersect as intersect +from sqlalchemy.sql import intersect_all as intersect_all +from sqlalchemy.sql import join as join +from sqlalchemy.sql import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from sqlalchemy.sql import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from sqlalchemy.sql import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from sqlalchemy.sql import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from sqlalchemy.sql import lambda_stmt as lambda_stmt +from sqlalchemy.sql import lateral as lateral +from sqlalchemy.sql import literal as literal +from sqlalchemy.sql import literal_column as literal_column +from sqlalchemy.sql import modifier as modifier +from sqlalchemy.sql import not_ as not_ +from sqlalchemy.sql import null as null +from sqlalchemy.sql import nulls_first as nulls_first +from sqlalchemy.sql import nulls_last as nulls_last +from sqlalchemy.sql import nullsfirst as nullsfirst +from sqlalchemy.sql import nullslast as nullslast +from sqlalchemy.sql import or_ as or_ +from sqlalchemy.sql import outerjoin as outerjoin +from sqlalchemy.sql import outparam as outparam +from sqlalchemy.sql import over as over +from sqlalchemy.sql import subquery as subquery +from sqlalchemy.sql import table as table +from sqlalchemy.sql import tablesample as tablesample +from sqlalchemy.sql import text as text +from sqlalchemy.sql import true as true +from 
sqlalchemy.sql import tuple_ as tuple_ +from sqlalchemy.sql import type_coerce as type_coerce +from sqlalchemy.sql import union as union +from sqlalchemy.sql import union_all as union_all +from sqlalchemy.sql import update as update +from sqlalchemy.sql import values as values +from sqlalchemy.sql import within_group as within_group +from sqlalchemy.types import ARRAY as ARRAY +from sqlalchemy.types import BIGINT as BIGINT +from sqlalchemy.types import BigInteger as BigInteger +from sqlalchemy.types import BINARY as BINARY +from sqlalchemy.types import BLOB as BLOB +from sqlalchemy.types import BOOLEAN as BOOLEAN +from sqlalchemy.types import Boolean as Boolean +from sqlalchemy.types import CHAR as CHAR +from sqlalchemy.types import CLOB as CLOB +from sqlalchemy.types import DATE as DATE +from sqlalchemy.types import Date as Date +from sqlalchemy.types import DATETIME as DATETIME +from sqlalchemy.types import DateTime as DateTime +from sqlalchemy.types import DECIMAL as DECIMAL +from sqlalchemy.types import Enum as Enum +from sqlalchemy.types import FLOAT as FLOAT +from sqlalchemy.types import Float as Float +from sqlalchemy.types import INT as INT +from sqlalchemy.types import INTEGER as INTEGER +from sqlalchemy.types import Integer as Integer +from sqlalchemy.types import Interval as Interval +from sqlalchemy.types import JSON as JSON +from sqlalchemy.types import LargeBinary as LargeBinary +from sqlalchemy.types import NCHAR as NCHAR +from sqlalchemy.types import NUMERIC as NUMERIC +from sqlalchemy.types import Numeric as Numeric +from sqlalchemy.types import NVARCHAR as NVARCHAR +from sqlalchemy.types import PickleType as PickleType +from sqlalchemy.types import REAL as REAL +from sqlalchemy.types import SMALLINT as SMALLINT +from sqlalchemy.types import SmallInteger as SmallInteger +from sqlalchemy.types import String as String +from sqlalchemy.types import TEXT as TEXT +from sqlalchemy.types import Text as Text +from sqlalchemy.types import TIME as TIME +from sqlalchemy.types import Time as Time +from sqlalchemy.types import TIMESTAMP as TIMESTAMP +from sqlalchemy.types import TypeDecorator as TypeDecorator +from sqlalchemy.types import Unicode as Unicode +from sqlalchemy.types import UnicodeText as UnicodeText +from sqlalchemy.types import VARBINARY as VARBINARY +from sqlalchemy.types import VARCHAR as VARCHAR + +# Extensions and modifications of SQLAlchemy in SQLModel +from .engine.create import create_engine as create_engine +from .orm.session import Session as Session +from .sql.expression import select as select +from .sql.expression import col as col +from .sql.sqltypes import AutoString as AutoString + +# Export SQLModel specifics (equivalent to Pydantic) +from .main import SQLModel as SQLModel +from .main import Field as Field +from .main import Relationship as Relationship diff --git a/sqlmodel/v1/default.py b/sqlmodel/v1/default.py new file mode 100644 index 0000000000..bb44972e24 --- /dev/null +++ b/sqlmodel/v1/default.py @@ -0,0 +1,32 @@ +from typing import Any, TypeVar + + +class _DefaultPlaceholder: + """ + You shouldn't use this class directly. + + It's used internally to recognize when a default value has been overwritten, even + if the overriden default value was truthy. 
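+
+    Illustrative behavior (an added sketch, not in the original): equality
+    compares the wrapped values, and a placeholder never equals a bare value,
+    so Default(5) == Default(5) is True while Default(5) == 5 is False. The
+    create_engine() wrapper in this patch filters the placeholders out with
+    isinstance(..., _DefaultPlaceholder) checks instead.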
+    """
+
+    def __init__(self, value: Any):
+        self.value = value
+
+    def __bool__(self) -> bool:
+        return bool(self.value)
+
+    def __eq__(self, o: object) -> bool:
+        return isinstance(o, _DefaultPlaceholder) and o.value == self.value
+
+
+_TDefaultType = TypeVar("_TDefaultType")
+
+
+def Default(value: _TDefaultType) -> _TDefaultType:
+    """
+    You shouldn't use this function directly.
+
+    It's used internally to recognize when a default value has been overwritten, even
+    if the overridden default value was truthy.
+    """
+    return _DefaultPlaceholder(value)  # type: ignore
diff --git a/sqlmodel/v1/engine/__init__.py b/sqlmodel/v1/engine/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sqlmodel/v1/engine/create.py b/sqlmodel/v1/engine/create.py
new file mode 100644
index 0000000000..b2d567b1b1
--- /dev/null
+++ b/sqlmodel/v1/engine/create.py
@@ -0,0 +1,139 @@
+import json
+import sqlite3
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+from sqlalchemy import create_engine as _create_engine
+from sqlalchemy.engine.url import URL
+from sqlalchemy.future import Engine as _FutureEngine
+from sqlalchemy.pool import Pool
+from typing_extensions import Literal, TypedDict
+
+from ..default import Default, _DefaultPlaceholder
+
+# Types defined in sqlalchemy2-stubs, but they can't be imported, so re-define here
+
+_Debug = Literal["debug"]
+
+_IsolationLevel = Literal[
+    "SERIALIZABLE",
+    "REPEATABLE READ",
+    "READ COMMITTED",
+    "READ UNCOMMITTED",
+    "AUTOCOMMIT",
+]
+_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat"]
+_ResetOnReturn = Literal["rollback", "commit"]
+
+
+class _SQLiteConnectArgs(TypedDict, total=False):
+    timeout: float
+    detect_types: Any
+    isolation_level: Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
+    check_same_thread: bool
+    factory: Type[sqlite3.Connection]
+    cached_statements: int
+    uri: bool
+
+
+_ConnectArgs = Union[_SQLiteConnectArgs, Dict[str, Any]]
+
+
+# Re-define create_engine so that future=True by default, and assume that's what
+# is used. Also show the default values used for each parameter, but don't pass
+# them along unless explicitly set by the user, to prevent errors. E.g. SQLite
+# doesn't support pool connection arguments.
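+#
+# Illustrative usage (a sketch, not part of the original module; the SQLite
+# URL is only an example):
+#
+#     engine = create_engine("sqlite:///database.db", echo=True)
+#
+# Only echo=True (plus the implicit future=True) reaches SQLAlchemy's
+# _create_engine(); every argument still wrapped in Default(...) is filtered
+# out by the isinstance(..., _DefaultPlaceholder) checks in the body below.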
+def create_engine( + url: Union[str, URL], + *, + connect_args: _ConnectArgs = Default({}), # type: ignore + echo: Union[bool, _Debug] = Default(False), + echo_pool: Union[bool, _Debug] = Default(False), + enable_from_linting: bool = Default(True), + encoding: str = Default("utf-8"), + execution_options: Dict[Any, Any] = Default({}), + future: bool = True, + hide_parameters: bool = Default(False), + implicit_returning: bool = Default(True), + isolation_level: Optional[_IsolationLevel] = Default(None), + json_deserializer: Callable[..., Any] = Default(json.loads), + json_serializer: Callable[..., Any] = Default(json.dumps), + label_length: Optional[int] = Default(None), + logging_name: Optional[str] = Default(None), + max_identifier_length: Optional[int] = Default(None), + max_overflow: int = Default(10), + module: Optional[Any] = Default(None), + paramstyle: Optional[_ParamStyle] = Default(None), + pool: Optional[Pool] = Default(None), + poolclass: Optional[Type[Pool]] = Default(None), + pool_logging_name: Optional[str] = Default(None), + pool_pre_ping: bool = Default(False), + pool_size: int = Default(5), + pool_recycle: int = Default(-1), + pool_reset_on_return: Optional[_ResetOnReturn] = Default("rollback"), + pool_timeout: float = Default(30), + pool_use_lifo: bool = Default(False), + plugins: Optional[List[str]] = Default(None), + query_cache_size: Optional[int] = Default(None), + **kwargs: Any, +) -> _FutureEngine: + current_kwargs: Dict[str, Any] = { + "future": future, + } + if not isinstance(echo, _DefaultPlaceholder): + current_kwargs["echo"] = echo + if not isinstance(echo_pool, _DefaultPlaceholder): + current_kwargs["echo_pool"] = echo_pool + if not isinstance(enable_from_linting, _DefaultPlaceholder): + current_kwargs["enable_from_linting"] = enable_from_linting + if not isinstance(connect_args, _DefaultPlaceholder): + current_kwargs["connect_args"] = connect_args + if not isinstance(encoding, _DefaultPlaceholder): + current_kwargs["encoding"] = encoding + if not isinstance(execution_options, _DefaultPlaceholder): + current_kwargs["execution_options"] = execution_options + if not isinstance(hide_parameters, _DefaultPlaceholder): + current_kwargs["hide_parameters"] = hide_parameters + if not isinstance(implicit_returning, _DefaultPlaceholder): + current_kwargs["implicit_returning"] = implicit_returning + if not isinstance(isolation_level, _DefaultPlaceholder): + current_kwargs["isolation_level"] = isolation_level + if not isinstance(json_deserializer, _DefaultPlaceholder): + current_kwargs["json_deserializer"] = json_deserializer + if not isinstance(json_serializer, _DefaultPlaceholder): + current_kwargs["json_serializer"] = json_serializer + if not isinstance(label_length, _DefaultPlaceholder): + current_kwargs["label_length"] = label_length + if not isinstance(logging_name, _DefaultPlaceholder): + current_kwargs["logging_name"] = logging_name + if not isinstance(max_identifier_length, _DefaultPlaceholder): + current_kwargs["max_identifier_length"] = max_identifier_length + if not isinstance(max_overflow, _DefaultPlaceholder): + current_kwargs["max_overflow"] = max_overflow + if not isinstance(module, _DefaultPlaceholder): + current_kwargs["module"] = module + if not isinstance(paramstyle, _DefaultPlaceholder): + current_kwargs["paramstyle"] = paramstyle + if not isinstance(pool, _DefaultPlaceholder): + current_kwargs["pool"] = pool + if not isinstance(poolclass, _DefaultPlaceholder): + current_kwargs["poolclass"] = poolclass + if not isinstance(pool_logging_name, 
_DefaultPlaceholder): + current_kwargs["pool_logging_name"] = pool_logging_name + if not isinstance(pool_pre_ping, _DefaultPlaceholder): + current_kwargs["pool_pre_ping"] = pool_pre_ping + if not isinstance(pool_size, _DefaultPlaceholder): + current_kwargs["pool_size"] = pool_size + if not isinstance(pool_recycle, _DefaultPlaceholder): + current_kwargs["pool_recycle"] = pool_recycle + if not isinstance(pool_reset_on_return, _DefaultPlaceholder): + current_kwargs["pool_reset_on_return"] = pool_reset_on_return + if not isinstance(pool_timeout, _DefaultPlaceholder): + current_kwargs["pool_timeout"] = pool_timeout + if not isinstance(pool_use_lifo, _DefaultPlaceholder): + current_kwargs["pool_use_lifo"] = pool_use_lifo + if not isinstance(plugins, _DefaultPlaceholder): + current_kwargs["plugins"] = plugins + if not isinstance(query_cache_size, _DefaultPlaceholder): + current_kwargs["query_cache_size"] = query_cache_size + current_kwargs.update(kwargs) + return _create_engine(url, **current_kwargs) # type: ignore diff --git a/sqlmodel/v1/engine/result.py b/sqlmodel/v1/engine/result.py new file mode 100644 index 0000000000..7a25422227 --- /dev/null +++ b/sqlmodel/v1/engine/result.py @@ -0,0 +1,79 @@ +from typing import Generic, Iterator, List, Optional, TypeVar + +from sqlalchemy.engine.result import Result as _Result +from sqlalchemy.engine.result import ScalarResult as _ScalarResult + +_T = TypeVar("_T") + + +class ScalarResult(_ScalarResult, Generic[_T]): + def all(self) -> List[_T]: + return super().all() + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: + return super().partitions(size) + + def fetchall(self) -> List[_T]: + return super().fetchall() + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: + return super().fetchmany(size) + + def __iter__(self) -> Iterator[_T]: + return super().__iter__() + + def __next__(self) -> _T: + return super().__next__() # type: ignore + + def first(self) -> Optional[_T]: + return super().first() + + def one_or_none(self) -> Optional[_T]: + return super().one_or_none() + + def one(self) -> _T: + return super().one() # type: ignore + + +class Result(_Result, Generic[_T]): + def scalars(self, index: int = 0) -> ScalarResult[_T]: + return super().scalars(index) # type: ignore + + def __iter__(self) -> Iterator[_T]: # type: ignore + return super().__iter__() # type: ignore + + def __next__(self) -> _T: # type: ignore + return super().__next__() # type: ignore + + def partitions(self, size: Optional[int] = None) -> Iterator[List[_T]]: # type: ignore + return super().partitions(size) # type: ignore + + def fetchall(self) -> List[_T]: # type: ignore + return super().fetchall() # type: ignore + + def fetchone(self) -> Optional[_T]: # type: ignore + return super().fetchone() # type: ignore + + def fetchmany(self, size: Optional[int] = None) -> List[_T]: # type: ignore + return super().fetchmany() # type: ignore + + def all(self) -> List[_T]: # type: ignore + return super().all() # type: ignore + + def first(self) -> Optional[_T]: # type: ignore + return super().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: # type: ignore + return super().one_or_none() # type: ignore + + def scalar_one(self) -> _T: + return super().scalar_one() # type: ignore + + def scalar_one_or_none(self) -> Optional[_T]: + return super().scalar_one_or_none() + + def one(self) -> _T: # type: ignore + return super().one() # type: ignore + + def scalar(self) -> Optional[_T]: + return super().scalar() diff --git 
a/sqlmodel/v1/ext/__init__.py b/sqlmodel/v1/ext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmodel/v1/ext/asyncio/__init__.py b/sqlmodel/v1/ext/asyncio/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmodel/v1/ext/asyncio/session.py b/sqlmodel/v1/ext/asyncio/session.py new file mode 100644 index 0000000000..80267b25e5 --- /dev/null +++ b/sqlmodel/v1/ext/asyncio/session.py @@ -0,0 +1,62 @@ +from typing import Any, Mapping, Optional, Sequence, TypeVar, Union + +from sqlalchemy import util +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.ext.asyncio import engine +from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.util.concurrency import greenlet_spawn +from sqlmodel.sql.base import Executable + +from ...engine.result import ScalarResult +from ...orm.session import Session +from ...sql.expression import Select + +_T = TypeVar("_T") + + +class AsyncSession(_AsyncSession): + sync_session: Session + + def __init__( + self, + bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, + binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, + **kw: Any, + ): + # All the same code of the original AsyncSession + kw["future"] = True + if bind: + self.bind = bind + bind = engine._get_sync_engine_or_connection(bind) # type: ignore + + if binds: + self.binds = binds + binds = { + key: engine._get_sync_engine_or_connection(b) # type: ignore + for key, b in binds.items() + } + + self.sync_session = self._proxied = self._assign_proxied( # type: ignore + Session(bind=bind, binds=binds, **kw) # type: ignore + ) + + async def exec( + self, + statement: Union[Select[_T], Executable[_T]], + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[Any, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[_T]: + # TODO: the documentation says execution_options accepts a dict, but only + # util.immutabledict has the union() method. Is this a bug in SQLAlchemy? 
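+        # Added note (an assumption, not in the original): util.EMPTY_DICT is a
+        # SQLAlchemy immutabledict, which is why .union() works on the default
+        # value; a plain dict passed by a caller would lack it, hence the
+        # type: ignore below. "prebuffer_rows" makes the sync session buffer
+        # result rows so they can still be consumed after the greenlet call
+        # returns.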
+ execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore + + return await greenlet_spawn( + self.sync_session.exec, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) diff --git a/sqlmodel/v1/main.py b/sqlmodel/v1/main.py new file mode 100644 index 0000000000..d343c698e9 --- /dev/null +++ b/sqlmodel/v1/main.py @@ -0,0 +1,655 @@ +import ipaddress +import uuid +import weakref +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from pathlib import Path +from typing import ( + AbstractSet, + Any, + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from pydantic import BaseConfig, BaseModel +from pydantic.errors import ConfigError, DictError +from pydantic.fields import SHAPE_SINGLETON +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import ModelField, Undefined, UndefinedType +from pydantic.main import ModelMetaclass, validate_model +from pydantic.typing import ForwardRef, NoArgAnyCallable, resolve_annotations +from pydantic.utils import ROOT_KEY, Representation +from sqlalchemy import Boolean, Column, Date, DateTime +from sqlalchemy import Enum as sa_Enum +from sqlalchemy import Float, ForeignKey, Integer, Interval, Numeric, inspect +from sqlalchemy.orm import RelationshipProperty, declared_attr, registry, relationship +from sqlalchemy.orm.attributes import set_attribute +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.orm.instrumentation import is_instrumented +from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.sqltypes import LargeBinary, Time + +from .sql.sqltypes import GUID, AutoString + +_T = TypeVar("_T") + + +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] 
= (()), +) -> Callable[[_T], _T]: + return lambda a: a + + +class FieldInfo(PydanticFieldInfo): + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + primary_key = kwargs.pop("primary_key", False) + nullable = kwargs.pop("nullable", Undefined) + foreign_key = kwargs.pop("foreign_key", Undefined) + unique = kwargs.pop("unique", False) + index = kwargs.pop("index", Undefined) + sa_column = kwargs.pop("sa_column", Undefined) + sa_column_args = kwargs.pop("sa_column_args", Undefined) + sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + if sa_column is not Undefined: + if sa_column_args is not Undefined: + raise RuntimeError( + "Passing sa_column_args is not supported when " + "also passing a sa_column" + ) + if sa_column_kwargs is not Undefined: + raise RuntimeError( + "Passing sa_column_kwargs is not supported when " + "also passing a sa_column" + ) + super().__init__(default=default, **kwargs) + self.primary_key = primary_key + self.nullable = nullable + self.foreign_key = foreign_key + self.unique = unique + self.index = index + self.sa_column = sa_column + self.sa_column_args = sa_column_args + self.sa_column_kwargs = sa_column_kwargs + + +class RelationshipInfo(Representation): + def __init__( + self, + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, + ) -> None: + if sa_relationship is not None: + if sa_relationship_args is not None: + raise RuntimeError( + "Passing sa_relationship_args is not supported when " + "also passing a sa_relationship" + ) + if sa_relationship_kwargs is not None: + raise RuntimeError( + "Passing sa_relationship_kwargs is not supported when " + "also passing a sa_relationship" + ) + self.back_populates = back_populates + self.link_model = link_model + self.sa_relationship = sa_relationship + self.sa_relationship_args = sa_relationship_args + self.sa_relationship_kwargs = sa_relationship_kwargs + + +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + include: Union[ + AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any + ] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + primary_key: bool = False, + foreign_key: Optional[Any] = None, + unique: bool = False, + nullable: Union[bool, UndefinedType] = Undefined, + index: Union[bool, UndefinedType] = Undefined, + sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore + sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, + sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + schema_extra: Optional[Dict[str, Any]] = None, +) -> Any: + current_schema_extra = schema_extra or {} + field_info = FieldInfo( + default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + 
include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + min_items=min_items, + max_items=max_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + primary_key=primary_key, + foreign_key=foreign_key, + unique=unique, + nullable=nullable, + index=index, + sa_column=sa_column, + sa_column_args=sa_column_args, + sa_column_kwargs=sa_column_kwargs, + **current_schema_extra, + ) + field_info._validate() + return field_info + + +def Relationship( + *, + back_populates: Optional[str] = None, + link_model: Optional[Any] = None, + sa_relationship: Optional[RelationshipProperty] = None, # type: ignore + sa_relationship_args: Optional[Sequence[Any]] = None, + sa_relationship_kwargs: Optional[Mapping[str, Any]] = None, +) -> Any: + relationship_info = RelationshipInfo( + back_populates=back_populates, + link_model=link_model, + sa_relationship=sa_relationship, + sa_relationship_args=sa_relationship_args, + sa_relationship_kwargs=sa_relationship_kwargs, + ) + return relationship_info + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +class SQLModelMetaclass(ModelMetaclass, DeclarativeMeta): + __sqlmodel_relationships__: Dict[str, RelationshipInfo] + __config__: Type[BaseConfig] + __fields__: Dict[str, ModelField] + + # Replicate SQLAlchemy + def __setattr__(cls, name: str, value: Any) -> None: + if getattr(cls.__config__, "table", False): + DeclarativeMeta.__setattr__(cls, name, value) + else: + super().__setattr__(name, value) + + def __delattr__(cls, name: str) -> None: + if getattr(cls.__config__, "table", False): + DeclarativeMeta.__delattr__(cls, name) + else: + super().__delattr__(name) + + # From Pydantic + def __new__( + cls, + name: str, + bases: Tuple[Type[Any], ...], + class_dict: Dict[str, Any], + **kwargs: Any, + ) -> Any: + relationships: Dict[str, RelationshipInfo] = {} + dict_for_pydantic = {} + original_annotations = resolve_annotations( + class_dict.get("__annotations__", {}), class_dict.get("__module__", None) + ) + pydantic_annotations = {} + relationship_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, RelationshipInfo): + relationships[k] = v + else: + dict_for_pydantic[k] = v + for k, v in original_annotations.items(): + if k in relationships: + relationship_annotations[k] = v + else: + pydantic_annotations[k] = v + dict_used = { + **dict_for_pydantic, + "__weakref__": None, + "__sqlmodel_relationships__": relationships, + "__annotations__": pydantic_annotations, + } + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: Set[str] = { + key + for key in dir(BaseConfig) + if not ( + key.startswith("__") and key.endswith("__") + ) # skip dunder methods and attributes + } + pydantic_kwargs = kwargs.copy() + config_kwargs = { + key: pydantic_kwargs.pop(key) + for key in pydantic_kwargs.keys() & allowed_config_kwargs + } + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls.__annotations__ = { + **relationship_annotations, + **pydantic_annotations, + **new_cls.__annotations__, + } + + def get_config(name: str) -> Any: + config_class_value = getattr(new_cls.__config__, name, Undefined) + if config_class_value is not Undefined: + return config_class_value + kwarg_value = kwargs.get(name, Undefined) + if kwarg_value is not Undefined: + return 
kwarg_value
+            return Undefined
+
+        config_table = get_config("table")
+        if config_table is True:
+            # If it was passed by kwargs, ensure it's also set in config
+            new_cls.__config__.table = config_table
+            for k, v in new_cls.__fields__.items():
+                col = get_column_from_field(v)
+                setattr(new_cls, k, col)
+            # Set a config flag to tell FastAPI that this should be read with a field
+            # in orm_mode instead of preemptively converting it to a dict.
+            # This could be done by reading new_cls.__config__.table in FastAPI, but
+            # that's very specific to SQLModel, so let's have another config that
+            # other future tools based on Pydantic can use.
+            new_cls.__config__.read_with_orm_mode = True
+
+        config_registry = get_config("registry")
+        if config_registry is not Undefined:
+            config_registry = cast(registry, config_registry)
+            # If it was passed by kwargs, ensure it's also set in config
+            new_cls.__config__.registry = config_registry
+            setattr(new_cls, "_sa_registry", config_registry)
+            setattr(new_cls, "metadata", config_registry.metadata)
+            setattr(new_cls, "__abstract__", True)
+        return new_cls
+
+    # Override SQLAlchemy, allow both SQLAlchemy and plain Pydantic models
+    def __init__(
+        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
+    ) -> None:
+        # Only one of the base classes (or the current one) should be a table model;
+        # this allows FastAPI to clone a SQLModel for the response_model without
+        # trying to create a new SQLAlchemy model, for a new table, with the same
+        # name, which would trigger an error
+        base_is_table = False
+        for base in bases:
+            config = getattr(base, "__config__")
+            if config and getattr(config, "table", False):
+                base_is_table = True
+                break
+        if getattr(cls.__config__, "table", False) and not base_is_table:
+            dict_used = dict_.copy()
+            for field_name, field_value in cls.__fields__.items():
+                dict_used[field_name] = get_column_from_field(field_value)
+            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
+                if rel_info.sa_relationship:
+                    # A SQLAlchemy relationship was declared explicitly; it takes
+                    # precedence over anything else, so use it and continue with
+                    # the next attribute
+                    dict_used[rel_name] = rel_info.sa_relationship
+                    continue
+                ann = cls.__annotations__[rel_name]
+                temp_field = ModelField.infer(
+                    name=rel_name,
+                    value=rel_info,
+                    annotation=ann,
+                    class_validators=None,
+                    config=BaseConfig,
+                )
+                relationship_to = temp_field.type_
+                if isinstance(temp_field.type_, ForwardRef):
+                    relationship_to = temp_field.type_.__forward_arg__
+                rel_kwargs: Dict[str, Any] = {}
+                if rel_info.back_populates:
+                    rel_kwargs["back_populates"] = rel_info.back_populates
+                if rel_info.link_model:
+                    ins = inspect(rel_info.link_model)
+                    local_table = getattr(ins, "local_table")
+                    if local_table is None:
+                        raise RuntimeError(
+                            "Couldn't find the secondary table for "
+                            f"model {rel_info.link_model}"
+                        )
+                    rel_kwargs["secondary"] = local_table
+                rel_args: List[Any] = []
+                if rel_info.sa_relationship_args:
+                    rel_args.extend(rel_info.sa_relationship_args)
+                if rel_info.sa_relationship_kwargs:
+                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
+                rel_value: RelationshipProperty = relationship(  # type: ignore
+                    relationship_to, *rel_args, **rel_kwargs
+                )
+                dict_used[rel_name] = rel_value
+                setattr(cls, rel_name, rel_value)  # Fix #315
+            DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
+        else:
+            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
+
+
+def get_sqlachemy_type(field: ModelField) -> Any:
+    if issubclass(field.type_, str):
+        if 
field.field_info.max_length: + return AutoString(length=field.field_info.max_length) + return AutoString + if issubclass(field.type_, float): + return Float + if issubclass(field.type_, bool): + return Boolean + if issubclass(field.type_, int): + return Integer + if issubclass(field.type_, datetime): + return DateTime + if issubclass(field.type_, date): + return Date + if issubclass(field.type_, timedelta): + return Interval + if issubclass(field.type_, time): + return Time + if issubclass(field.type_, Enum): + return sa_Enum(field.type_) + if issubclass(field.type_, bytes): + return LargeBinary + if issubclass(field.type_, Decimal): + return Numeric( + precision=getattr(field.type_, "max_digits", None), + scale=getattr(field.type_, "decimal_places", None), + ) + if issubclass(field.type_, ipaddress.IPv4Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv4Network): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Address): + return AutoString + if issubclass(field.type_, ipaddress.IPv6Network): + return AutoString + if issubclass(field.type_, Path): + return AutoString + if issubclass(field.type_, uuid.UUID): + return GUID + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + + +def get_column_from_field(field: ModelField) -> Column: # type: ignore + sa_column = getattr(field.field_info, "sa_column", Undefined) + if isinstance(sa_column, Column): + return sa_column + sa_type = get_sqlachemy_type(field) + primary_key = getattr(field.field_info, "primary_key", False) + index = getattr(field.field_info, "index", Undefined) + if index is Undefined: + index = False + nullable = not primary_key and _is_field_noneable(field) + # Override derived nullability if the nullable property is set explicitly + # on the field + if hasattr(field.field_info, "nullable"): + field_nullable = getattr(field.field_info, "nullable") + if field_nullable != Undefined: + nullable = field_nullable + args = [] + foreign_key = getattr(field.field_info, "foreign_key", None) + unique = getattr(field.field_info, "unique", False) + if foreign_key: + args.append(ForeignKey(foreign_key)) + kwargs = { + "primary_key": primary_key, + "nullable": nullable, + "index": index, + "unique": unique, + } + sa_default = Undefined + if field.field_info.default_factory: + sa_default = field.field_info.default_factory + elif field.field_info.default is not Undefined: + sa_default = field.field_info.default + if sa_default is not Undefined: + kwargs["default"] = sa_default + sa_column_args = getattr(field.field_info, "sa_column_args", Undefined) + if sa_column_args is not Undefined: + args.extend(list(cast(Sequence[Any], sa_column_args))) + sa_column_kwargs = getattr(field.field_info, "sa_column_kwargs", Undefined) + if sa_column_kwargs is not Undefined: + kwargs.update(cast(Dict[Any, Any], sa_column_kwargs)) + return Column(sa_type, *args, **kwargs) # type: ignore + + +class_registry = weakref.WeakValueDictionary() # type: ignore + +default_registry = registry() + + +def _value_items_is_true(v: Any) -> bool: + # Re-implement Pydantic's ValueItems.is_true() as it hasn't been released as of + # the current latest, Pydantic 1.8.2 + return v is True or v is ... 
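+
+# Illustrative checks (an added sketch, not in the original): True and Ellipsis
+# both mean "include the whole field" in Pydantic's include/exclude maps:
+#
+#     assert _value_items_is_true(True)
+#     assert _value_items_is_true(...)
+#     assert not _value_items_is_true({"name": True})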
+ + +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") + + +class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): + # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values + __slots__ = ("__weakref__",) + __tablename__: ClassVar[Union[str, Callable[..., str]]] + __sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty]] # type: ignore + __name__: ClassVar[str] + metadata: ClassVar[MetaData] + + class Config: + orm_mode = True + + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + new_object = super().__new__(cls) + # SQLAlchemy doesn't call __init__ on the base class + # Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html + # Set __fields_set__ here, that would have been set when calling __init__ + # in the Pydantic model so that when SQLAlchemy sets attributes that are + # added (e.g. when querying from DB) to the __fields_set__, this already exists + object.__setattr__(new_object, "__fields_set__", set()) + return new_object + + def __init__(__pydantic_self__, **data: Any) -> None: + # Uses something other than `self` the first arg to allow "self" as a + # settable attribute + values, fields_set, validation_error = validate_model( + __pydantic_self__.__class__, data + ) + # Only raise errors if not a SQLModel model + if ( + not getattr(__pydantic_self__.__config__, "table", False) + and validation_error + ): + raise validation_error + # Do not set values as in Pydantic, pass them through setattr, so SQLAlchemy + # can handle them + # object.__setattr__(__pydantic_self__, '__dict__', values) + for key, value in values.items(): + setattr(__pydantic_self__, key, value) + object.__setattr__(__pydantic_self__, "__fields_set__", fields_set) + non_pydantic_keys = data.keys() - values.keys() + for key in non_pydantic_keys: + if key in __pydantic_self__.__sqlmodel_relationships__: + setattr(__pydantic_self__, key, data[key]) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"_sa_instance_state"}: + self.__dict__[name] = value + return + else: + # Set in SQLAlchemy, before Pydantic to trigger events and updates + if getattr(self.__config__, "table", False) and is_instrumented(self, name): + set_attribute(self, name, value) + # Set in Pydantic model to trigger possible validation changes, only for + # non relationship values + if name not in self.__sqlmodel_relationships__: + super().__setattr__(name, value) + + @classmethod + def from_orm( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + # Duplicated from Pydantic + if not cls.__config__.orm_mode: + raise ConfigError( + "You must have the config attribute orm_mode=True to use from_orm" + ) + obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + if not getattr(cls.__config__, "table", False): + # If not table, normal Pydantic code + m: _TSQLModel = cls.__new__(cls) + else: + # If table, create the new instance normally to make SQLAlchemy create + # the _sa_instance_state attribute + m = cls() + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + # Updated to trigger SQLAlchemy internal handling + if not getattr(cls.__config__, "table", False): + object.__setattr__(m, "__dict__", values) + else: + for key, value in values.items(): + setattr(m, key, value) + # Continue with standard Pydantic logic + 
object.__setattr__(m, "__fields_set__", fields_set) + m._init_private_attributes() + return m + + @classmethod + def parse_obj( + cls: Type[_TSQLModel], obj: Any, update: Optional[Dict[str, Any]] = None + ) -> _TSQLModel: + obj = cls._enforce_dict_if_root(obj) + # SQLModel, support update dict + if update is not None: + obj = {**obj, **update} + # End SQLModel support dict + return super().parse_obj(obj) + + def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]: + # Don't show SQLAlchemy private attributes + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_sa_")] + + # From Pydantic, override to enforce validation with dict + @classmethod + def validate(cls: Type[_TSQLModel], value: Any) -> _TSQLModel: + if isinstance(value, cls): + return value.copy() if cls.__config__.copy_on_model_validation else value + + value = cls._enforce_dict_if_root(value) + if isinstance(value, dict): + values, fields_set, validation_error = validate_model(cls, value) + if validation_error: + raise validation_error + model = cls(**value) + # Reset fields set, this would have been done in Pydantic in __init__ + object.__setattr__(model, "__fields_set__", fields_set) + return model + elif cls.__config__.orm_mode: + return cls.from_orm(value) + elif cls.__custom_root_type__: + return cls.parse_obj(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + # From Pydantic, override to only show keys from fields, omit SQLAlchemy attributes + def _calculate_keys( + self, + include: Optional[Mapping[Union[int, str], Any]], + exclude: Optional[Mapping[Union[int, str], Any]], + exclude_unset: bool, + update: Optional[Dict[str, Any]] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and not exclude_unset: + # Original in Pydantic: + # return None + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + return self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() + else: + # Original in Pydantic: + # keys = self.__dict__.keys() + # Updated to not return SQLAlchemy attributes + # Do not include relationships as that would easily lead to infinite + # recursion, or traversing the whole database + keys = self.__fields__.keys() # | self.__sqlmodel_relationships__.keys() + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _value_items_is_true(v)} + + return keys + + @declared_attr # type: ignore + def __tablename__(cls) -> str: + return cls.__name__.lower() + + +def _is_field_noneable(field: ModelField) -> bool: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + return field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return False diff --git a/sqlmodel/v1/orm/__init__.py b/sqlmodel/v1/orm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmodel/v1/orm/session.py b/sqlmodel/v1/orm/session.py new file mode 100644 index 0000000000..1692fdcbcb --- /dev/null +++ b/sqlmodel/v1/orm/session.py @@ -0,0 +1,141 @@ +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload + +from sqlalchemy import util +from 
sqlalchemy.orm import Query as _Query +from sqlalchemy.orm import Session as _Session +from sqlalchemy.sql.base import Executable as _Executable +from sqlmodel.sql.expression import Select, SelectOfScalar +from typing_extensions import Literal + +from ..engine.result import Result, ScalarResult +from ..sql.base import Executable + +_TSelectParam = TypeVar("_TSelectParam") + + +class Session(_Session): + @overload + def exec( + self, + statement: Select[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Result[_TSelectParam]: + ... + + @overload + def exec( + self, + statement: SelectOfScalar[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> ScalarResult[_TSelectParam]: + ... + + def exec( + self, + statement: Union[ + Select[_TSelectParam], + SelectOfScalar[_TSelectParam], + Executable[_TSelectParam], + ], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: + results = super().execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + **kw, + ) + if isinstance(statement, SelectOfScalar): + return results.scalars() # type: ignore + return results # type: ignore + + def execute( + self, + statement: _Executable, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + **kw: Any, + ) -> Result[Any]: + """ + 🚨 You probably want to use `session.exec()` instead of `session.execute()`. + + This is the original SQLAlchemy `session.execute()` method that returns objects + of type `Row`, and that you have to call `scalars()` to get the model objects. + + For example: + + ```Python + heroes = session.execute(select(Hero)).scalars().all() + ``` + + instead you could use `exec()`: + + ```Python + heroes = session.exec(select(Hero)).all() + ``` + """ + return super().execute( # type: ignore + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + **kw, + ) + + def query(self, *entities: Any, **kwargs: Any) -> "_Query[Any]": + """ + 🚨 You probably want to use `session.exec()` instead of `session.query()`. + + `session.exec()` is SQLModel's own short version with increased type + annotations. + + Or otherwise you might want to use `session.execute()` instead of + `session.query()`. 
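+
+        For example, with the same `Hero` model as above:
+
+        ```Python
+        heroes = session.query(Hero).all()
+        ```
+
+        could be written as:
+
+        ```Python
+        heroes = session.exec(select(Hero)).all()
+        ```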
+ """ + return super().query(*entities, **kwargs) + + def get( + self, + entity: Type[_TSelectParam], + ident: Any, + options: Optional[Sequence[Any]] = None, + populate_existing: bool = False, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + identity_token: Optional[Any] = None, + execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, + ) -> Optional[_TSelectParam]: + return super().get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) diff --git a/sqlmodel/v1/pool/__init__.py b/sqlmodel/v1/pool/__init__.py new file mode 100644 index 0000000000..20bb952525 --- /dev/null +++ b/sqlmodel/v1/pool/__init__.py @@ -0,0 +1 @@ +from sqlalchemy.pool import StaticPool as StaticPool # noqa: F401 diff --git a/sqlmodel/v1/py.typed b/sqlmodel/v1/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmodel/v1/sql/__init__.py b/sqlmodel/v1/sql/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sqlmodel/v1/sql/base.py b/sqlmodel/v1/sql/base.py new file mode 100644 index 0000000000..3764a9721d --- /dev/null +++ b/sqlmodel/v1/sql/base.py @@ -0,0 +1,9 @@ +from typing import Generic, TypeVar + +from sqlalchemy.sql.base import Executable as _Executable + +_T = TypeVar("_T") + + +class Executable(_Executable, Generic[_T]): + pass diff --git a/sqlmodel/v1/sql/expression.py b/sqlmodel/v1/sql/expression.py new file mode 100644 index 0000000000..31c0bc1a1e --- /dev/null +++ b/sqlmodel/v1/sql/expression.py @@ -0,0 +1,458 @@ +# WARNING: do not modify this code, it is generated by expression.py.jinja2 + +import sys +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID + +from sqlalchemy import Column +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.expression import Select as _Select + +_TSelect = TypeVar("_TSelect") + +# Workaround Generics incompatibility in Python 3.6 +# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322 +if sys.version_info.minor >= 7: + + class Select(_Select, Generic[_TSelect]): + inherit_cache = True + + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different + # purpose. This is the same as a normal SQLAlchemy Select class where there's only one + # entity, so the result will be converted to a scalar by default. This way writing + # for loops on the results will feel natural. 
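+    #
+    # For example (with any table models Hero and Team): select(Hero) produces
+    # SelectOfScalar[Hero], and Session.exec() calls .scalars() on its result,
+    # so session.exec(select(Hero)).all() returns model instances directly,
+    # while select(Hero, Team) produces Select[Tuple[Hero, Team]] and yields
+    # row tuples.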
+ class SelectOfScalar(_Select, Generic[_TSelect]): + inherit_cache = True + +else: + from typing import GenericMeta # type: ignore + + class GenericSelectMeta(GenericMeta, _Select.__class__): # type: ignore + pass + + class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + inherit_cache = True + + class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta): + inherit_cache = True + + # Cast them for editors to work correctly, from several tricks tried, this works + # for both VS Code and PyCharm + Select = cast("Select", _Py36Select) # type: ignore + SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar) # type: ignore + + +if TYPE_CHECKING: # pragma: no cover + from ..main import SQLModel + +# Generated TypeVars start + + +_TScalar_0 = TypeVar( + "_TScalar_0", + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_0 = TypeVar("_TModel_0", bound="SQLModel") + + +_TScalar_1 = TypeVar( + "_TScalar_1", + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_1 = TypeVar("_TModel_1", bound="SQLModel") + + +_TScalar_2 = TypeVar( + "_TScalar_2", + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_2 = TypeVar("_TModel_2", bound="SQLModel") + + +_TScalar_3 = TypeVar( + "_TScalar_3", + Column, # type: ignore + Sequence, # type: ignore + Mapping, # type: ignore + UUID, + datetime, + float, + int, + bool, + bytes, + str, + None, +) + +_TModel_3 = TypeVar("_TModel_3", bound="SQLModel") + + +# Generated TypeVars end + + +@overload +def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]: # type: ignore + ... + + +@overload +def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]: # type: ignore + ... + + +# Generated overloads start + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2]]: + ... 
+ + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TScalar_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: _TScalar_0, + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TScalar_0, _TModel_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TScalar_3]]: + ... 
+ + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: _TScalar_1, + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TScalar_1, _TModel_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: _TScalar_2, + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TScalar_2, _TModel_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: _TScalar_3, + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TScalar_3]]: + ... + + +@overload +def select( # type: ignore + entity_0: Type[_TModel_0], + entity_1: Type[_TModel_1], + entity_2: Type[_TModel_2], + entity_3: Type[_TModel_3], + **kw: Any, +) -> Select[Tuple[_TModel_0, _TModel_1, _TModel_2, _TModel_3]]: + ... + + +# Generated overloads end + + +def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]: # type: ignore + if len(entities) == 1: + return SelectOfScalar._create(*entities, **kw) # type: ignore + return Select._create(*entities, **kw) # type: ignore + + +# TODO: add several @overload from Python types to SQLAlchemy equivalents +def col(column_expression: Any) -> ColumnClause: # type: ignore + if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)): + raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}") + return column_expression diff --git a/sqlmodel/v1/sql/expression.py.jinja2 b/sqlmodel/v1/sql/expression.py.jinja2 new file mode 100644 index 0000000000..51f04a215d --- /dev/null +++ b/sqlmodel/v1/sql/expression.py.jinja2 @@ -0,0 +1,118 @@ +import sys +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID + +from sqlalchemy import Column +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.expression import Select as _Select + +_TSelect = TypeVar("_TSelect") + +# Workaround Generics incompatibility in Python 3.6 +# Ref: https://github.com/python/typing/issues/449#issuecomment-316061322 +if sys.version_info.minor >= 7: + + class Select(_Select, Generic[_TSelect]): + inherit_cache = True + + # This is not comparable to sqlalchemy.sql.selectable.ScalarSelect, that has a different + # purpose. This is the same as a normal SQLAlchemy Select class where there's only one + # entity, so the result will be converted to a scalar by default. This way writing + # for loops on the results will feel natural. 
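+    #
+    # For example (with any table models Hero and Team): select(Hero) produces
+    # SelectOfScalar[Hero], and Session.exec() calls .scalars() on its result,
+    # so session.exec(select(Hero)).all() returns model instances directly,
+    # while select(Hero, Team) produces Select[Tuple[Hero, Team]] and yields
+    # row tuples.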
+    class SelectOfScalar(_Select, Generic[_TSelect]):
+        inherit_cache = True
+
+else:
+    from typing import GenericMeta  # type: ignore
+
+    class GenericSelectMeta(GenericMeta, _Select.__class__):  # type: ignore
+        pass
+
+    class _Py36Select(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        inherit_cache = True
+
+    class _Py36SelectOfScalar(_Select, Generic[_TSelect], metaclass=GenericSelectMeta):
+        inherit_cache = True
+
+    # Cast them for editors to work correctly, from several tricks tried, this works
+    # for both VS Code and PyCharm
+    Select = cast("Select", _Py36Select)  # type: ignore
+    SelectOfScalar = cast("SelectOfScalar", _Py36SelectOfScalar)  # type: ignore
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from ..main import SQLModel
+
+# Generated TypeVars start
+
+{% for i in range(number_of_types) %}
+_TScalar_{{ i }} = TypeVar(
+    "_TScalar_{{ i }}",
+    Column,  # type: ignore
+    Sequence,  # type: ignore
+    Mapping,  # type: ignore
+    UUID,
+    datetime,
+    float,
+    int,
+    bool,
+    bytes,
+    str,
+    None,
+)
+
+_TModel_{{ i }} = TypeVar("_TModel_{{ i }}", bound="SQLModel")
+
+{% endfor %}
+
+# Generated TypeVars end
+
+@overload
+def select(entity_0: _TScalar_0, **kw: Any) -> SelectOfScalar[_TScalar_0]:  # type: ignore
+    ...
+
+
+@overload
+def select(entity_0: Type[_TModel_0], **kw: Any) -> SelectOfScalar[_TModel_0]:  # type: ignore
+    ...
+
+
+# Generated overloads start
+
+{% for signature in signatures %}
+
+@overload
+def select(  # type: ignore
+    {% for arg in signature[0] %}{{ arg.name }}: {{ arg.annotation }}, {% endfor %}**kw: Any,
+    ) -> Select[Tuple[{%for ret in signature[1] %}{{ ret }} {% if not loop.last %}, {% endif %}{% endfor %}]]:
+    ...
+
+{% endfor %}
+
+# Generated overloads end
+
+def select(*entities: Any, **kw: Any) -> Union[Select, SelectOfScalar]:  # type: ignore
+    if len(entities) == 1:
+        return SelectOfScalar._create(*entities, **kw)  # type: ignore
+    return Select._create(*entities, **kw)  # type: ignore
+
+
+# TODO: add several @overload from Python types to SQLAlchemy equivalents
+def col(column_expression: Any) -> ColumnClause:  # type: ignore
+    if not isinstance(column_expression, (ColumnClause, Column, InstrumentedAttribute)):
+        raise RuntimeError(f"Not a SQLAlchemy column: {column_expression}")
+    return column_expression
diff --git a/sqlmodel/v1/sql/sqltypes.py b/sqlmodel/v1/sql/sqltypes.py
new file mode 100644
index 0000000000..09b8239476
--- /dev/null
+++ b/sqlmodel/v1/sql/sqltypes.py
@@ -0,0 +1,60 @@
+import uuid
+from typing import Any, Optional, cast
+
+from sqlalchemy import CHAR, types
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.engine.interfaces import Dialect
+from sqlalchemy.sql.type_api import TypeEngine
+
+
+class AutoString(types.TypeDecorator):  # type: ignore
+
+    impl = types.String
+    cache_ok = True
+    mysql_default_length = 255
+
+    def load_dialect_impl(self, dialect: Dialect) -> "types.TypeEngine[Any]":
+        impl = cast(types.String, self.impl)
+        if impl.length is None and dialect.name == "mysql":
+            return dialect.type_descriptor(types.String(self.mysql_default_length))  # type: ignore
+        return super().load_dialect_impl(dialect)
+
+
+# Reference from SQLAlchemy docs: https://docs.sqlalchemy.org/en/14/core/custom_types.html#backend-agnostic-guid-type
+# with small modifications
+class GUID(types.TypeDecorator):  # type: ignore
+    """Platform-independent GUID type.
+
+    Uses PostgreSQL's UUID type, otherwise uses
+    CHAR(32), storing as stringified hex values.
+ + """ + + impl = CHAR + cache_ok = True + + def load_dialect_impl(self, dialect: Dialect) -> TypeEngine: # type: ignore + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID()) # type: ignore + else: + return dialect.type_descriptor(CHAR(32)) # type: ignore + + def process_bind_param(self, value: Any, dialect: Dialect) -> Optional[str]: + if value is None: + return value + elif dialect.name == "postgresql": + return str(value) + else: + if not isinstance(value, uuid.UUID): + return uuid.UUID(value).hex + else: + # hexstring + return value.hex + + def process_result_value(self, value: Any, dialect: Dialect) -> Optional[uuid.UUID]: + if value is None: + return value + else: + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return cast(uuid.UUID, value)
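To close, a quick sanity-check sketch (not part of the patch) of how `GUID` and `AutoString` behave end to end. It assumes the v1-compatible package is importable as `sqlmodel`; the `Ticket` model is hypothetical, and an in-memory SQLite database stands in for any backend:

```Python
import uuid
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine, select


class Ticket(SQLModel, table=True):
    # uuid.UUID fields map to GUID: native UUID on PostgreSQL, CHAR(32)
    # hex strings elsewhere; the plain str field is backed by AutoString
    id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
    code: Optional[str] = None


engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Ticket(code="abc"))
    session.commit()

# A fresh session forces a real round trip through process_bind_param()
# and process_result_value()
with Session(engine) as session:
    ticket = session.exec(select(Ticket)).first()
    assert ticket is not None
    assert isinstance(ticket.id, uuid.UUID)
```

Opening a second session is deliberate: it ensures the value is re-read from the database through `process_result_value()`, which converts the stored 32-character hex string back into a `uuid.UUID`, rather than being served from the first session's identity map.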