Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] Add support for SQLAlchemy polymorphic models #1226

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

PaleNeutron
Copy link

@PaleNeutron PaleNeutron commented Nov 26, 2024

Introduce support for SQLAlchemy polymorphic models by adjusting field defaults and handling inheritance correctly in the SQLModel metaclass. Add tests to verify functionality with polymorphic joined and single table inheritance. Refer to #36 .

@PaleNeutron PaleNeutron marked this pull request as ready for review November 26, 2024 03:26
@PaleNeutron PaleNeutron changed the title Support SQLAlchemy polymorphic models Add support for SQLAlchemy polymorphic models Nov 26, 2024
@PaleNeutron PaleNeutron changed the title Add support for SQLAlchemy polymorphic models [feature] Add support for SQLAlchemy polymorphic models Nov 26, 2024
@mmx86
Copy link

mmx86 commented Dec 2, 2024

@tiangolo
Could you please comment on whether this request has a good chance of being merged?
My team and I, being under time constraints, are currently trying to decide whether to commit to this feature already.

sqlmodel/_compat.py Outdated Show resolved Hide resolved
Co-authored-by: John Pocock <John-P@users.noreply.github.com>
@ndeybach
Copy link

We are also exploring using SQLModel in our products. This would be quite an ease of life in how we are building our stack.

@tiangolo do you have a timeline as to when could this be merged / what needs to be done ?

@guhur
Copy link

guhur commented Jan 3, 2025

Thanks a lot for this PR! We would love to add this feature in our codebase.

Unfortunately, we could not use this PR along with a custom type.

@PaleNeutron would you mind checking this MRE?

(1) the code works fine if you comment DarkHero or if you comment my_model.

(2) however, it fails if both are in the module!

Code

import json
import typing as t

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, TypeAdapter

# Warning: we import a deprecated class from the `pydantic` package
# See: https://github.com/pydantic/pydantic/issues/6381
from pydantic._internal._model_construction import ModelMetaclass  # noqa: PLC2701
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType
from sqlmodel import (
    JSON,
    Column,
    Field,
    Session,
    SQLModel,
    TypeDecorator,
    create_engine,
    select,
)


def pydantic_column_type(  # noqa: C901
    pydantic_type: type[t.Any],
) -> type[TypeDecorator]:
    """
    See details here:
    https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
    """
    T = t.TypeVar("T")

    class PydanticJSONType(TypeDecorator, t.Generic[T]):
        impl = JSON()
        cache_ok = False

        def __init__(
            self,
            json_encoder: t.Any = json,
        ):
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect: Dialect) -> _BindProcessorType[T] | None:
            impl_processor = self.impl.bind_processor(dialect)
            if impl_processor:

                def process(value: T | None) -> T | None:
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            value_to_dump = pydantic_type.model_validate(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value: T | None) -> T | None:
                    if isinstance(pydantic_type, ModelMetaclass):
                        value_to_dump = pydantic_type.model_validate(value)
                    else:
                        value_to_dump = value
                    return jsonable_encoder(value_to_dump)

            return process

        def result_processor(
            self,
            dialect: Dialect,
            coltype: object,
        ) -> _ResultProcessorType[T] | None:
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value: T) -> T | None:
                    value = impl_processor(value)
                    if value is None:
                        return None

                    if isinstance(value, str):
                        value = json.loads(value)

                    return TypeAdapter(pydantic_type).validate_python(value)

            else:

                def process(value: T) -> T | None:
                    if value is None:
                        return None

                    if isinstance(value, str):
                        value = json.loads(value)

                    return TypeAdapter(pydantic_type).validate_python(value)

            return process

        def compare_values(self, x: t.Any, y: t.Any) -> bool:
            return x == y

    return PydanticJSONType


class MyModel(BaseModel):
    name: str | None = None


class ComplexModel(SQLModel, table=True):
    id: t.Annotated[
        int | None,
        Field(
            default=None,
            primary_key=True,
        ),
    ] = None
    my_model: t.Annotated[
        MyModel | None,
        Field(
            sa_column=Column(pydantic_column_type(MyModel)),
        ),
    ] = None


class Hero(SQLModel, table=True):
    __tablename__ = "hero"
    id: int | None = Field(default=None, primary_key=True)
    hero_type: str = Field(default="hero")

    __mapper_args__ = {
        "polymorphic_on": "hero_type",
        "polymorphic_identity": "hero",
    }


class DarkHero(Hero):
    dark_power: str = Field(
        default="dark",
        sa_column=mapped_column(
            nullable=False, use_existing_column=True, default="dark"
        ),
    )

    __mapper_args__ = {
        "polymorphic_identity": "dark",
    }


engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
    hero = Hero()
    db.add(hero)
    dark_hero = DarkHero(dark_power="pokey")
    db.add(dark_hero)
    db.commit()
    statement = select(DarkHero)
    result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)

Corresponding error code

python test.py
Traceback (most recent call last):
  File "/Users/guhur/src/argile-lib-python/test.py", line 101, in <module>
    class DarkHero(Hero):
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/sqlmodel/main.py", line 542, in __new__
    new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 202, in __new__
    complete_model_class(
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 572, in complete_model_class
    generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 159, in generate_pydantic_signature
    merged_params = _generate_signature_parameters(init, fields, config_wrapper)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 115, in _generate_signature_parameters
    kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/fields.py", line 546, in get_default
    return _utils.smart_deepcopy(self.default)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_utils.py", line 318, in smart_deepcopy
    return deepcopy(obj)  # slowest way when we actually might need a deepcopy
           ^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 265, in _reconstruct
    y = func(*args)
        ^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 264, in <genexpr>
    args = (deepcopy(arg, memo) for arg in args)
            ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 265, in _reconstruct
    y = func(*args)
        ^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 264, in <genexpr>
    args = (deepcopy(arg, memo) for arg in args)
            ^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 206, in _deepcopy_list
    append(deepcopy(a, memo))
           ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 211, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 211, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
         ^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 161, in deepcopy
    rv = reductor(4)
         ^^^^^^^^^^^
TypeError: cannot pickle 'module' object

@adsharma
Copy link

This could help with a different kind of polymorphism. Details here.

Specifically:

@fquery.sqlmodel.model()
@dataclass
class Hero:
  ...

Creates two classes Hero (dataclass) and HereoSQLModel (sqlmodel).

Using polymorphism, we could allow the caller to return Hero (dataclass, cheaper when static typing is good enough) or HeroSQLModel if runtime validation is needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants