Skip to content

Commit

Permalink
fix: favour SA mapped type over impl type
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Mar 23, 2024
1 parent 719495e commit 89309bd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
9 changes: 4 additions & 5 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type:
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
annotation = (
column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues]
if hasattr(column.type, "impl")
else column.type.python_type
)
try:
annotation = column.type.python_type
except NotImplementedError:
annotation = column.type.impl.python_type # type: ignore[attr-defined]

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]
Expand Down
48 changes: 45 additions & 3 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Type
from typing import Any, Callable, Type, Union
from uuid import UUID

import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types
Expand Down Expand Up @@ -141,8 +142,7 @@ class Model(Base):
def triple_age(self) -> int:
return self.age * 3 # type: ignore[no-any-return]

class ModelFactory(SQLAlchemyFactory[Model]):
...
class ModelFactory(SQLAlchemyFactory[Model]): ...

instance = ModelFactory.build()
assert isinstance(instance, Model)
Expand Down Expand Up @@ -347,3 +347,45 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]):

result = ModelFactory.build()
assert isinstance(result.name, str)


@pytest.mark.parametrize("python_type_", (UUID, None))
@pytest.mark.parametrize(
"impl_",
(
types.Uuid(),
types.Uuid(native_uuid=False),
types.CHAR(32),
),
)
def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine, python_type_: Union[type, None]) -> None:
class CustomType(types.TypeDecorator):
impl = impl_
cache_ok = True

if python_type_ is not None:

@property
def python_type(self) -> type:
return python_type_

class Base(orm.DeclarativeBase):
type_annotation_map = {
UUID: CustomType,
}

class Model(Base):
__tablename__ = "model_with_custom_types"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
custom_type: orm.Mapped[UUID] = orm.mapped_column(type_=CustomType(), nullable=False)
custom_type_from_annotation_map: orm.Mapped[UUID]

class ModelFactory(SQLAlchemyFactory[Model]):
__model__ = Model

instance = ModelFactory.build()

expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type
assert isinstance(instance.custom_type, expected_type)
assert isinstance(instance.custom_type_from_annotation_map, expected_type)

0 comments on commit 89309bd

Please sign in to comment.