-
-
Notifications
You must be signed in to change notification settings - Fork 688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feature] Add support for SQLAlchemy polymorphic models #1226
base: main
Are you sure you want to change the base?
[feature] Add support for SQLAlchemy polymorphic models #1226
Conversation
@tiangolo |
Co-authored-by: John Pocock <John-P@users.noreply.github.com>
We are also exploring using SQLModel in our products. This would be quite an ease of life in how we are building our stack. @tiangolo do you have a timeline as to when could this be merged / what needs to be done ? |
Thanks a lot for this PR! We would love to add this feature in our codebase. Unfortunately, we could not use this PR along with a custom type. @PaleNeutron would you mind checking this MRE? (1) the code works fine if you comment (2) however, it fails if both are in the module! Code
import json
import typing as t
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, TypeAdapter
# Warning: we import a deprecated class from the `pydantic` package
# See: https://github.com/pydantic/pydantic/issues/6381
from pydantic._internal._model_construction import ModelMetaclass # noqa: PLC2701
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import mapped_column
from sqlalchemy.sql.type_api import _BindProcessorType, _ResultProcessorType
from sqlmodel import (
JSON,
Column,
Field,
Session,
SQLModel,
TypeDecorator,
create_engine,
select,
)
def pydantic_column_type( # noqa: C901
pydantic_type: type[t.Any],
) -> type[TypeDecorator]:
"""
See details here:
https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
"""
T = t.TypeVar("T")
class PydanticJSONType(TypeDecorator, t.Generic[T]):
impl = JSON()
cache_ok = False
def __init__(
self,
json_encoder: t.Any = json,
):
self.json_encoder = json_encoder
super().__init__()
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[T] | None:
impl_processor = self.impl.bind_processor(dialect)
if impl_processor:
def process(value: T | None) -> T | None:
if value is not None:
if isinstance(pydantic_type, ModelMetaclass):
value_to_dump = pydantic_type.model_validate(value)
else:
value_to_dump = value
value = jsonable_encoder(value_to_dump)
return impl_processor(value)
else:
def process(value: T | None) -> T | None:
if isinstance(pydantic_type, ModelMetaclass):
value_to_dump = pydantic_type.model_validate(value)
else:
value_to_dump = value
return jsonable_encoder(value_to_dump)
return process
def result_processor(
self,
dialect: Dialect,
coltype: object,
) -> _ResultProcessorType[T] | None:
impl_processor = self.impl.result_processor(dialect, coltype)
if impl_processor:
def process(value: T) -> T | None:
value = impl_processor(value)
if value is None:
return None
if isinstance(value, str):
value = json.loads(value)
return TypeAdapter(pydantic_type).validate_python(value)
else:
def process(value: T) -> T | None:
if value is None:
return None
if isinstance(value, str):
value = json.loads(value)
return TypeAdapter(pydantic_type).validate_python(value)
return process
def compare_values(self, x: t.Any, y: t.Any) -> bool:
return x == y
return PydanticJSONType
class MyModel(BaseModel):
name: str | None = None
class ComplexModel(SQLModel, table=True):
id: t.Annotated[
int | None,
Field(
default=None,
primary_key=True,
),
] = None
my_model: t.Annotated[
MyModel | None,
Field(
sa_column=Column(pydantic_column_type(MyModel)),
),
] = None
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: int | None = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")
__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}
class DarkHero(Hero):
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)
__mapper_args__ = {
"polymorphic_identity": "dark",
}
engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero(dark_power="pokey")
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str) Corresponding error code
python test.py
Traceback (most recent call last):
File "/Users/guhur/src/argile-lib-python/test.py", line 101, in <module>
class DarkHero(Hero):
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/sqlmodel/main.py", line 542, in __new__
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 202, in __new__
complete_model_class(
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_model_construction.py", line 572, in complete_model_class
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 159, in generate_pydantic_signature
merged_params = _generate_signature_parameters(init, fields, config_wrapper)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_signature.py", line 115, in _generate_signature_parameters
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/fields.py", line 546, in get_default
return _utils.smart_deepcopy(self.default)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/guhur/Library/Caches/pypoetry/virtualenvs/argile-lib-python-RxGRaJe1-py3.11/lib/python3.11/site-packages/pydantic/_internal/_utils.py", line 318, in smart_deepcopy
return deepcopy(obj) # slowest way when we actually might need a deepcopy
^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 265, in _reconstruct
y = func(*args)
^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 264, in <genexpr>
args = (deepcopy(arg, memo) for arg in args)
^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 265, in _reconstruct
y = func(*args)
^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 264, in <genexpr>
args = (deepcopy(arg, memo) for arg in args)
^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 206, in _deepcopy_list
append(deepcopy(a, memo))
^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 211, in _deepcopy_tuple
y = [deepcopy(a, memo) for a in x]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 211, in <listcomp>
y = [deepcopy(a, memo) for a in x]
^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 271, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 146, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 231, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/homebrew/Cellar/python@3.11/3.11.8/Frameworks/Python.framework/Versions/3.11/lib/python3.11/copy.py", line 161, in deepcopy
rv = reductor(4)
^^^^^^^^^^^
TypeError: cannot pickle 'module' object
|
This could help with a different kind of polymorphism. Details here. Specifically:
Creates two classes Using polymorphism, we could allow the caller to return |
Introduce support for SQLAlchemy polymorphic models by adjusting field defaults and handling inheritance correctly in the SQLModel metaclass. Add tests to verify functionality with polymorphic joined and single table inheritance. Refer to #36 .