Skip to content

Commit

Permalink
feat: update sqlalchemy writer for v2 (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmbhm1 authored Aug 1, 2024
1 parent 0e0b160 commit 1c5b9a5
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 47 deletions.
73 changes: 73 additions & 0 deletions supabase_pydantic/util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class ModelGenerationType(str, Enum):
'serial': ('int', None),
'bigserial': ('int', None),
'money': ('Decimal', 'from decimal import Decimal'),
'character varying': ('str', None),
'character varying(n)': ('str', None),
'varchar(n)': ('str', None),
'character(n)': ('str', None),
Expand Down Expand Up @@ -137,6 +138,7 @@ class ModelGenerationType(str, Enum):
'serial': ('Integer', 'from sqlalchemy import Integer'), # Auto increment in context
'bigserial': ('BigInteger', 'from sqlalchemy import BigInteger'), # Auto increment in context
'money': ('Numeric', 'from sqlalchemy import Numeric'), # No specific Money type in SQLAlchemy
'character varying': ('String', 'from sqlalchemy import String'),
'character varying(n)': ('String', 'from sqlalchemy import String'),
'varchar(n)': ('String', 'from sqlalchemy import String'),
'character(n)': ('String', 'from sqlalchemy import String'),
Expand Down Expand Up @@ -175,6 +177,77 @@ class ModelGenerationType(str, Enum):
}


SQLALCHEMY_V2_TYPE_MAP: dict[str, tuple[str, str | None]] = {
'integer': ('Integer,int', 'from sqlalchemy import Integer'),
'bigint': ('BigInteger,int', 'from sqlalchemy import BigInteger'),
'smallint': ('SmallInteger,int', 'from sqlalchemy import SmallInteger'),
'numeric': ('Numeric,float', 'from sqlalchemy import Numeric'),
'decimal': ('Numeric,float', 'from sqlalchemy import Numeric'),
'real': ('Float,float', 'from sqlalchemy import Float'),
'double precision': ('Float,float', 'from sqlalchemy import Float'),
'serial': ('Integer,int', 'from sqlalchemy import Integer'), # Auto increment in context
'bigserial': ('BigInteger,int', 'from sqlalchemy import BigInteger'), # Auto increment in context
'money': (
'Numeric,Decimal',
'from sqlalchemy import Numeric\nfrom decimal import Decimal',
), # No specific Money type in SQLAlchemy
'character varying': ('String,str', 'from sqlalchemy import String'),
'character varying(n)': ('String,str', 'from sqlalchemy import String'),
'varchar(n)': ('String,str', 'from sqlalchemy import String'),
'character(n)': ('String,str', 'from sqlalchemy import String'),
'char(n)': ('String,str', 'from sqlalchemy import String'),
'text': ('Text,str', 'from sqlalchemy import Text'),
'bytea': ('LargeBinary,bytes', 'from sqlalchemy import LargeBinary'),
'timestamp': ('DateTime,datetime', 'from sqlalchemy import DateTime\nfrom datetime import datetime'),
'timestamp with time zone': (
'DateTime,datetime',
'from sqlalchemy.dialects.postgresql import TIMESTAMP\nfrom datetime import datetime',
),
'timestamp without time zone': (
'DateTime,datetime',
'from sqlalchemy import DateTime\nfrom datetime import datetime',
),
'date': ('Date,date', 'from sqlalchemy import Date\nfrom datetime import date'),
'time': ('Time,time', 'from sqlalchemy import Time\nfrom datetime import time'),
'time with time zone': (
'Time,datetime.time',
'from sqlalchemy.dialects.postgresql import TIME\nfrom datetime import time',
),
'interval': ('Interval,timedelta', 'from sqlalchemy import Interval\nfrom datetime import timedelta'),
'boolean': ('Boolean,bool', 'from sqlalchemy import Boolean'),
'enum': ('Enum,str', 'from sqlalchemy import Enum'), # Enums need specific handling based on defined values
'point': (
'PickleType,Tuple[float, float]',
'from sqlalchemy import PickleType\nfrom typeing import Tuple',
), # No direct mapping, custom handling
'line': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'lseg': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'box': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'path': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'polygon': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'circle': ('PickleType,Any', 'from sqlalchemy import PickleType\nfrom typing import Any'),
'cidr': (
'CIDR,IPv4Network',
'from sqlalchemy.dialects.postgresql import CIDR\nfrom ipaddress import IPv4Network',
),
'inet': (
'INET,IPv4Address | IPv6Address',
'from sqlalchemy.dialects.postgresql import INET\nfrom ipaddress import IPv4Address, IPv6Address',
),
'macaddr': ('MACADDR,str', 'from sqlalchemy.dialects.postgresql import MACADDR'),
'macaddr8': ('MACADDR8,str', 'from sqlalchemy.dialects.postgresql import MACADDR8'),
'bit': ('BIT,str', 'from sqlalchemy.dialects.postgresql import BIT'),
'bit varying': ('BIT,str', 'from sqlalchemy.dialects.postgresql import BIT'),
'tsvector': ('TSVECTOR,str', 'from sqlalchemy.dialects.postgresql import TSVECTOR'),
'tsquery': ('TSQUERY,str', 'from sqlalchemy.dialects.postgresql import TSQUERY'),
'uuid': ('UUID,UUID4', 'from sqlalchemy.dialects.postgresql import UUID\nfrom pydantic import UUID4'),
'xml': ('Text,str', 'from sqlalchemy import Text'), # XML handled as Text for simplicity
'json': ('JSON,dict | Json', 'from sqlalchemy import JSON\nfrom pydantic import Json'),
'jsonb': ('JSONB,dict | Json', 'from sqlalchemy.dialects.postgresql import JSONB\nfrom pydantic import Json'),
'ARRAY': ('ARRAY,list', 'from sqlalchemy.dialects.postgresql import ARRAY'), # Generic ARRAY; specify further
}


# Queries

GET_TABLE_COLUMN_DETAILS = """
Expand Down
16 changes: 14 additions & 2 deletions supabase_pydantic/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from supabase_pydantic.util.constants import (
PYDANTIC_TYPE_MAP,
SQLALCHEMY_TYPE_MAP,
SQLALCHEMY_V2_TYPE_MAP,
FrameWorkType,
OrmType,
WriterConfig,
Expand Down Expand Up @@ -58,7 +59,9 @@ def get_enum_member_from_string(cls: Any, value: str) -> Any:


def adapt_type_map(
postgres_type: str, default_type: tuple[str, str | None], type_map: dict[str, tuple[str, str | None]]
postgres_type: str,
default_type: tuple[str, str | None],
type_map: dict[str, tuple[str, str | None]],
) -> tuple[str, str | None]:
"""Adapt a PostgreSQL data type to a Pydantic and SQLAlchemy type."""
array_suffix = '[]'
Expand All @@ -76,12 +79,21 @@ def adapt_type_map(


def get_sqlalchemy_type(
    postgres_type: str, default: tuple[str, str | None] = ('String', 'from sqlalchemy import String')
) -> tuple[str, str | None]:
    """Get the SQLAlchemy (v1-style) type from the PostgreSQL type.

    Args:
        postgres_type: PostgreSQL type name, e.g. 'integer' or 'text[]'.
        default: Fallback ``(type_name, import_statement)`` used when the type
            is not found in ``SQLALCHEMY_TYPE_MAP``.

    Returns:
        ``(sqlalchemy_type, import_statement_or_None)``.
    """
    return adapt_type_map(postgres_type, default, SQLALCHEMY_TYPE_MAP)


def get_sqlalchemy_v2_type(
    postgres_type: str, default: tuple[str, str | None] = ('String,str', 'from sqlalchemy import String')
) -> tuple[str, str, str | None]:
    """Get the SQLAlchemy v2 type from the PostgreSQL type.

    Map entries pack the SQLAlchemy type and the Python annotation type into a
    single comma-separated string (e.g. ``'Integer,int'``); unpack them here.

    Args:
        postgres_type: PostgreSQL type name, e.g. 'integer' or 'text[]'.
        default: Fallback ``('SqlType,py_type', import_statement)`` pair.

    Returns:
        ``(sqlalchemy_type, python_type, import_statements_or_None)``.
    """
    both_types, imports = adapt_type_map(postgres_type, default, SQLALCHEMY_V2_TYPE_MAP)
    # Split on the FIRST comma only: the Python-side type may itself contain
    # commas, e.g. 'PickleType,Tuple[float, float]' for 'point'. A plain
    # split(',') would raise ValueError on such entries.
    sql, py = both_types.split(',', 1)
    return (sql, py, imports)


def get_pydantic_type(postgres_type: str, default: tuple[str, str | None] = ('Any', None)) -> tuple[str, str | None]:
    """Resolve a PostgreSQL type name to its Pydantic type and required import."""
    resolved = adapt_type_map(postgres_type, default, PYDANTIC_TYPE_MAP)
    return resolved
Expand Down
33 changes: 22 additions & 11 deletions supabase_pydantic/util/writers/sqlalchemy_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from supabase_pydantic.util.constants import RelationType
from supabase_pydantic.util.dataclasses import ColumnInfo, SortedColumns, TableInfo
from supabase_pydantic.util.util import get_sqlalchemy_type, to_pascal_case
from supabase_pydantic.util.util import get_sqlalchemy_v2_type, to_pascal_case
from supabase_pydantic.util.writers.abstract_classes import AbstractClassWriter, AbstractFileWriter
from supabase_pydantic.util.writers.util import get_section_comment

Expand Down Expand Up @@ -33,31 +33,31 @@ def write_docs(self) -> str:
def write_column(self, c: ColumnInfo) -> str:
    """Generate a SQLAlchemy v2 ``Mapped``/``mapped_column`` definition line for one column.

    Args:
        c: Column metadata (name, PostgreSQL datatype, nullability, key flags).

    Returns:
        A single source line of the form
        ``name: Mapped[py_type] = mapped_column(SqlType, ...)``.
    """
    # Resolve both the SQLAlchemy column type and the Python annotation type.
    # (The legacy v1 get_sqlalchemy_type/Column path was removed in the v2 rewrite.)
    base_type, pyth_type, _ = get_sqlalchemy_v2_type(c.post_gres_datatype)
    if base_type.lower() == 'uuid':
        base_type = 'UUID(as_uuid=True)'
    if 'time zone' in c.post_gres_datatype.lower():
        # NOTE(review): this also rewrites 'time with time zone' to
        # TIMESTAMP(timezone=True) rather than TIME — confirm intended.
        base_type = 'TIMESTAMP(timezone=True)'
    # Python-side annotation; nullable columns become 'T | None'.
    col_dtype = f'{pyth_type}' + (' | None' if c.is_nullable else '')

    # Keyword arguments for mapped_column(); ForeignKey(...) positional args go first.
    field_values = dict()
    field_values_list_first = list()
    if c.is_nullable:
        field_values['nullable'] = 'True'
    if c.primary:
        field_values['primary_key'] = 'True'
    if c.is_unique:
        field_values['unique'] = 'True'
    for fk in self.table.foreign_keys:
        if c.name == fk.column_name:
            field_values_list_first.append(f'ForeignKey("{fk.foreign_table_name}.{fk.foreign_column_name}")')

    field_values_string = ', '.join(field_values_list_first) if len(field_values_list_first) > 0 else ''
    if len(field_values) > 0:
        if len(field_values_list_first) > 0:
            field_values_string += ', '
        field_values_string += ', '.join([f'{k}={v}' for k, v in field_values.items()])

    return f'{c.name}: Mapped[{col_dtype}] = mapped_column({base_type}{", " + field_values_string if (field_values_string is not None and bool(field_values_string)) else ""})'  # noqa: E501

def write_primary_keys(self) -> str | None:
"""Method to generate primary key definitions for the class."""
Expand Down Expand Up @@ -106,12 +106,12 @@ def __init__(
super().__init__(tables, file_path, writer)

def _dt_imports(
    self, imports: set, default_import: tuple[Any, Any | None] = ('String,str', 'from sqlalchemy import String')
) -> None:
    """Update *imports* in place with the import lines needed by all column data types.

    Args:
        imports: Mutable set of import-statement strings; entries may contain
            multiple newline-separated import lines.
        default_import: Fallback ``('SqlType,py_type', import_statement)`` used
            for unmapped PostgreSQL types.
    """

    def _pyi(c: ColumnInfo) -> str | None:  # pyi = pydantic import  # noqa
        # Index [2] is the import-statement string of the v2 triple (may be None).
        return get_sqlalchemy_v2_type(c.post_gres_datatype, default_import)[2]

    # column data types
    imports.update(filter(None, map(_pyi, (c for t in self.tables for c in t.columns))))
def write_imports(self) -> str:
    """Generate the sorted import block for the generated SQLAlchemy v2 file.

    Returns:
        Newline-joined, sorted, de-duplicated import statements.
    """
    # Standard v2 declarative imports (DeclarativeBase/Mapped/mapped_column
    # replace the v1 declarative_base/Column pair).
    imports = {
        'from sqlalchemy import ForeignKey',
        'from sqlalchemy.orm import DeclarativeBase',
        'from sqlalchemy.orm import Mapped',
        'from sqlalchemy.orm import mapped_column',
    }
    if any(len(t.primary_key()) > 0 for t in self.tables):
        imports.add('from sqlalchemy import PrimaryKeyConstraint')

    # column data types
    self._dt_imports(imports)

    # Type-map entries may bundle several imports separated by '\n';
    # flatten them so sorting and de-duplication work per line.
    new_imports = set()
    for i in imports:
        new_imports.update(i.split('\n'))

    return '\n'.join(sorted(new_imports))

def _class_writer_helper(
self,
Expand All @@ -157,9 +162,15 @@ def _method(t: TableInfo) -> Any:

def write_custom_classes(self, add_fk: bool = False) -> str:
    """Write the SQLAlchemy v2 declarative base class section.

    Args:
        add_fk: Unused here; kept for interface compatibility with the
            abstract writer.

    Returns:
        The rendered 'Declarative Base' section as a string.
    """
    # v2 style: subclass DeclarativeBase instead of calling declarative_base().
    # (The stale duplicate classes_override keyword from the v1 version was
    # removed — passing the same keyword twice is a SyntaxError.)
    declarative_base_class = (
        'class Base(DeclarativeBase):\n\t'
        '"""Declarative Base Class."""\n\t'
        '# type_annotation_map = {}\n\n\t'
        'pass'
    )
    return self._class_writer_helper(
        comment_title='Declarative Base',
        classes_override=[declarative_base_class],
    )

def write_base_classes(self) -> str:
Expand Down
Loading

0 comments on commit 1c5b9a5

Please sign in to comment.