diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index 62a7c0a5..5a2155b2 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type: elif issubclass(column_type, types.ARRAY): annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined] else: - annotation = ( - column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues] - if hasattr(column.type, "impl") - else column.type.python_type - ) + try: + annotation = column.type.python_type + except NotImplementedError: + annotation = column.type.impl.python_type # type: ignore[attr-defined] if column.nullable: annotation = Union[annotation, None] # type: ignore[assignment] diff --git a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py index 6a7657da..ee2db4bc 100644 --- a/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py +++ b/tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Type +from typing import Any, Callable, Type, Union +from uuid import UUID import pytest from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types @@ -141,8 +142,7 @@ class Model(Base): def triple_age(self) -> int: return self.age * 3 # type: ignore[no-any-return] - class ModelFactory(SQLAlchemyFactory[Model]): - ... + class ModelFactory(SQLAlchemyFactory[Model]): ... instance = ModelFactory.build() assert isinstance(instance, Model) @@ -347,3 +347,45 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]): result = ModelFactory.build() assert isinstance(result.name, str) + + +@pytest.mark.parametrize("python_type_", (UUID, None)) +@pytest.mark.parametrize( + "impl_", + ( + types.Uuid(), + types.Uuid(native_uuid=False), + types.CHAR(32), + ), +) +def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine, python_type_: Union[type, None]) -> None: + class CustomType(types.TypeDecorator): + impl = impl_ + cache_ok = True + + if python_type_ is not None: + + @property + def python_type(self) -> type: + return python_type_ + + class Base(orm.DeclarativeBase): + type_annotation_map = { + UUID: CustomType, + } + + class Model(Base): + __tablename__ = "model_with_custom_types" + + id: orm.Mapped[int] = orm.mapped_column(primary_key=True) + custom_type: orm.Mapped[UUID] = orm.mapped_column(type_=CustomType(), nullable=False) + custom_type_from_annotation_map: orm.Mapped[UUID] + + class ModelFactory(SQLAlchemyFactory[Model]): + __model__ = Model + + instance = ModelFactory.build() + + expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type + assert isinstance(instance.custom_type, expected_type) + assert isinstance(instance.custom_type_from_annotation_map, expected_type)