Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add support for passing a custom SQLAlchemy type to Field() with sa_type #505

Merged
merged 13 commits into from
Oct 29, 2023
16 changes: 14 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
foreign_key = kwargs.pop("foreign_key", Undefined)
unique = kwargs.pop("unique", False)
index = kwargs.pop("index", Undefined)
sa_type = kwargs.pop("sa_type", Undefined)
sa_column = kwargs.pop("sa_column", Undefined)
sa_column_args = kwargs.pop("sa_column_args", Undefined)
sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined)
Expand Down Expand Up @@ -104,18 +105,23 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
)
if unique is not Undefined:
raise RuntimeError(
"Passing unique is not supported when " "also passing a sa_column"
"Passing unique is not supported when also passing a sa_column"
)
if index is not Undefined:
raise RuntimeError(
"Passing index is not supported when " "also passing a sa_column"
"Passing index is not supported when also passing a sa_column"
)
if sa_type is not Undefined:
raise RuntimeError(
"Passing sa_type is not supported when also passing a sa_column"
)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.nullable = nullable
self.foreign_key = foreign_key
self.unique = unique
self.index = index
self.sa_type = sa_type
self.sa_column = sa_column
self.sa_column_args = sa_column_args
self.sa_column_kwargs = sa_column_kwargs
Expand Down Expand Up @@ -185,6 +191,7 @@ def Field(
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_type: Union[Type[Any], UndefinedType] = Undefined,
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
schema_extra: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -264,6 +271,7 @@ def Field(
unique: Union[bool, UndefinedType] = Undefined,
nullable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
sa_type: Union[Type[Any], UndefinedType] = Undefined,
sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore
sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined,
sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined,
Expand Down Expand Up @@ -300,6 +308,7 @@ def Field(
unique=unique,
nullable=nullable,
index=index,
sa_type=sa_type,
sa_column=sa_column,
sa_column_args=sa_column_args,
sa_column_kwargs=sa_column_kwargs,
Expand Down Expand Up @@ -515,6 +524,9 @@ def __init__(


def get_sqlalchemy_type(field: ModelField) -> Any:
sa_type = getattr(field.field_info, "sa_type", Undefined) # noqa: B009
if sa_type is not Undefined:
return sa_type
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
if issubclass(field.type_, Enum):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_field_sa_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ class Item(SQLModel, table=True):
)


def test_sa_column_no_type() -> None:
with pytest.raises(RuntimeError):

class Item(SQLModel, table=True):
id: Optional[int] = Field(
default=None,
sa_type=Integer,
sa_column=Column(Integer, primary_key=True),
)


def test_sa_column_no_primary_key() -> None:
with pytest.raises(RuntimeError):

Expand Down