Skip to content

Commit

Permalink
feat: add ability to inherit from all-null parent classes in pydantic…
Browse files Browse the repository at this point in the history
… fastapi models (#38)
  • Loading branch information
kmbhm1 authored Aug 5, 2024
1 parent e03fb3e commit 6c6c673
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 47 deletions.
2 changes: 2 additions & 0 deletions docs/other/to-do.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@
- [ ] Add fake generator for inserts and seed data
- [ ] Add supabase_secret key connection method
- [ ] Add mysql and other conns ...
- [x] Add option for nullified parent classe
- [ ] Change versioning to timestamp and latest
23 changes: 19 additions & 4 deletions supabase_pydantic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
load_dotenv(find_dotenv())

# Standard choices
model_choices = ['pydantic', 'sqlalchemy']
framework_choices = ['fastapi', 'fastapi-jsonapi']
# model_choices = ['pydantic', 'sqlalchemy']
# framework_choices = ['fastapi', 'fastapi-jsonapi']
model_choices = ['pydantic']
framework_choices = ['fastapi']


def check_readiness(env_vars: dict[str, str | None]) -> bool:
Expand Down Expand Up @@ -144,6 +146,7 @@ def clean(ctx: Any, directory: str) -> None:
'models',
multiple=True,
default=['pydantic'],
show_default=True,
type=click.Choice(model_choices, case_sensitive=False),
required=False,
help='The model type to generate. This can be a space separated list of valid model types. Default is "pydantic".',
Expand All @@ -154,6 +157,7 @@ def clean(ctx: Any, directory: str) -> None:
'frameworks',
multiple=True,
default=['fastapi'],
show_default=True,
type=click.Choice(framework_choices, case_sensitive=False),
required=False,
help='The framework to generate code for. This can be a space separated list of valid frameworks. Default is "fastapi".', # noqa: E501
Expand All @@ -174,16 +178,25 @@ def clean(ctx: Any, directory: str) -> None:
'default_directory',
multiple=False,
default='entities',
show_default=True,
type=click.Path(exists=False, file_okay=False, dir_okay=True, resolve_path=True),
help='The directory to save files. Defaults to "entities".',
help='The directory to save files',
required=False,
)
@click.option('--overwrite/--no-overwrite', default=True, help='Overwrite existing files. Defaults to overwrite.')
@click.option(
'--null-parent-classes',
is_flag=True,
show_default=True,
default=False,
help='In addition to the generated base classes, generate null parent classes for those base classes. For Pydantic models only.', # noqa: E501
)
def gen(
models: tuple[str],
frameworks: tuple[str],
default_directory: str,
overwrite: bool,
null_parent_classes: bool,
local: bool = False,
# linked: bool = False,
db_url: str | None = None,
Expand Down Expand Up @@ -243,7 +256,9 @@ def gen(
factory = FileWriterFactory()
for job, c in jobs.items(): # c = config
print(f'Generating {job} models...')
p = factory.get_file_writer(tables, c.fpath(), c.file_type, c.framework_type).save(overwrite)
p = factory.get_file_writer(
tables, c.fpath(), c.file_type, c.framework_type, add_null_parent_classes=null_parent_classes
).save(overwrite)
paths.append(p)
print(f'{job} models generated successfully: {p}')

Expand Down
8 changes: 8 additions & 0 deletions supabase_pydantic/util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
from typing import TypedDict


class WriterClassType(Enum):
"""Enum for writer class types."""

BASE = 'base'
BASE_WITH_PARENT = 'base_with_parent'
PARENT = 'parent'


class DatabaseConnectionType(Enum):
"""Enum for database connection types."""

Expand Down
23 changes: 18 additions & 5 deletions supabase_pydantic/util/writers/abstract_classes.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from abc import ABC, abstractmethod
from pathlib import Path

from supabase_pydantic.util.constants import BASE_CLASS_POSTFIX
from supabase_pydantic.util.constants import BASE_CLASS_POSTFIX, WriterClassType
from supabase_pydantic.util.dataclasses import TableInfo
from supabase_pydantic.util.util import to_pascal_case
from supabase_pydantic.util.writers.util import generate_unique_filename


class AbstractClassWriter(ABC):
def __init__(self, table: TableInfo, nullify_base_schema_class: bool = False):
def __init__(
self, table: TableInfo, class_type: WriterClassType = WriterClassType.BASE, null_defaults: bool = False
):
self.table = table
self.nullify_base_schema_class = nullify_base_schema_class
self.class_type = class_type
self._null_defaults = null_defaults
self.name = to_pascal_case(self.table.name)

@staticmethod
def _proper_name(name: str, use_base: bool = False) -> str:
return to_pascal_case(name) + (BASE_CLASS_POSTFIX if use_base else '')

def write_class(self, add_fk: bool = False) -> str:
def write_class(
self,
add_fk: bool = False,
) -> str:
"""Method to write the complete class definition."""
return self.write_definition() + self.write_docs() + self.write_columns(add_fk)

Expand Down Expand Up @@ -77,9 +83,16 @@ def write_columns(self, add_fk: bool = False) -> str:


class AbstractFileWriter(ABC):
def __init__(self, tables: list[TableInfo], file_path: str, writer: type[AbstractClassWriter]):
def __init__(
self,
tables: list[TableInfo],
file_path: str,
writer: type[AbstractClassWriter],
add_null_parent_classes: bool = False,
):
self.tables = tables
self.file_path = file_path
self.add_null_parent_classes = add_null_parent_classes
self.writer = writer
self.jstr = '\n\n\n'

Expand Down
4 changes: 3 additions & 1 deletion supabase_pydantic/util/writers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def get_file_writer(
file_path: str,
file_type: OrmType = OrmType.PYDANTIC,
framework_type: FrameWorkType = FrameWorkType.FASTAPI,
add_null_parent_classes: bool = False,
) -> AbstractFileWriter:
"""Get the file writer based on the provided parameters.
Expand All @@ -20,6 +21,7 @@ def get_file_writer(
file_path (str): The file path.
file_type (OrmType, optional): The ORM type. Defaults to OrmType.PYDANTIC.
framework_type (FrameWorkType, optional): The framework type. Defaults to FrameWorkType.FASTAPI.
add_null_parent_classes (bool, optional): Add null parent classes for base classes. Defaults to False.
Returns:
The file writer instance.
Expand All @@ -30,7 +32,7 @@ def get_file_writer(
case OrmType.SQLALCHEMY, FrameWorkType.FASTAPI_JSONAPI:
return SqlAlchemyJSONAPIWriter(tables, file_path)
case OrmType.PYDANTIC, FrameWorkType.FASTAPI:
return PydanticFastAPIWriter(tables, file_path)
return PydanticFastAPIWriter(tables, file_path, add_null_parent_classes=add_null_parent_classes)
case OrmType.PYDANTIC, FrameWorkType.FASTAPI_JSONAPI:
return PydanticJSONAPIWriter(tables, file_path)
case _:
Expand Down
72 changes: 55 additions & 17 deletions supabase_pydantic/util/writers/pydantic_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
CUSTOM_JSONAPI_META_MODEL_NAME,
CUSTOM_MODEL_NAME,
RelationType,
WriterClassType,
)
from supabase_pydantic.util.dataclasses import ColumnInfo, ForeignKeyInfo, SortedColumns, TableInfo
from supabase_pydantic.util.util import get_pydantic_type
Expand All @@ -15,30 +16,37 @@


class PydanticFastAPIClassWriter(AbstractClassWriter):
def __init__(self, table: TableInfo, nullify_base_schema_class: bool = False):
super().__init__(table, nullify_base_schema_class)
def __init__(
self, table: TableInfo, class_type: WriterClassType = WriterClassType.BASE, null_defaults: bool = False
):
super().__init__(table, class_type, null_defaults)

self.separated_columns: SortedColumns = self.table.sort_and_separate_columns(
separate_nullable=False, separate_primary_key=True
)

def write_name(self) -> str:
"""Method to generate the header for the base class."""
return f'{self.name}' + f'{post(self.table.table_type, self.nullify_base_schema_class)}'
return f'{self.name}' + f'{post(self.table.table_type, self.class_type)}'

def write_metaclass(self, metaclasses: list[str] | None = None) -> str | None:
"""Method to generate the metaclasses for the class."""
if metaclasses is not None and (isinstance(metaclasses, list), len(metaclasses) > 0):
metaclasses = metaclasses or []
if len(metaclasses) > 0:
return ', '.join(metaclasses)
return CUSTOM_MODEL_NAME
if self.class_type == WriterClassType.PARENT or self.class_type == WriterClassType.BASE:
return CUSTOM_MODEL_NAME
else:
metaclasses.append(f'{self.name}{post(self.table.table_type, WriterClassType.PARENT)}')
return ', '.join(metaclasses)

def write_column(self, c: ColumnInfo) -> str:
"""Method to generate the column definition for the class."""
base_type = get_pydantic_type(c.post_gres_datatype, ('str', None))[0]
base_type = f'{base_type} | None' if (c.is_nullable or self.nullify_base_schema_class) else base_type
base_type = f'{base_type} | None' if (c.is_nullable or self._null_defaults) else base_type

field_values = dict()
if (c.is_nullable is not None and c.is_nullable) or self.nullify_base_schema_class:
if (c.is_nullable is not None and c.is_nullable) or self._null_defaults:
field_values['default'] = 'None'
if c.alias is not None:
field_values['alias'] = f'"{c.alias}"'
Expand All @@ -51,7 +59,7 @@ def write_column(self, c: ColumnInfo) -> str:

def write_docs(self) -> str:
"""Method to generate the docstrings for the class."""
qualifier = 'Nullable Base' if self.nullify_base_schema_class else 'Base'
qualifier = '(Nullable) Parent' if self._null_defaults else 'Base'
return f'\n\t"""{self.name} {qualifier} Schema."""\n\n'

def write_primary_keys(self) -> str | None:
Expand Down Expand Up @@ -101,9 +109,13 @@ def write_operational_class(self) -> str | None:

class PydanticFastAPIWriter(AbstractFileWriter):
def __init__(
self, tables: list[TableInfo], file_path: str, writer: type[AbstractClassWriter] = PydanticFastAPIClassWriter
self,
tables: list[TableInfo],
file_path: str,
writer: type[AbstractClassWriter] = PydanticFastAPIClassWriter,
add_null_parent_classes: bool = False,
):
super().__init__(tables, file_path, writer)
super().__init__(tables, file_path, writer, add_null_parent_classes)

def _dt_imports(self, imports: set, default_import: tuple[Any, Any | None] = (Any, None)) -> None:
"""Update the imports with the necessary data types."""
Expand Down Expand Up @@ -135,18 +147,24 @@ def _class_writer_helper(
comments: list[str] = [],
classes_override: list[str] = [],
is_base: bool = True,
class_type: WriterClassType = WriterClassType.BASE,
**kwargs: Any,
) -> str:
sxn = get_section_comment(comment_title, comments)
classes = classes_override

if len(classes_override) == 0:
attr = 'write_class' if is_base else 'write_operational_class'

def _method(t: TableInfo) -> Any:
if class_type == WriterClassType.PARENT:
return getattr(self.writer(t, class_type, True), attr)
elif class_type == WriterClassType.BASE_WITH_PARENT:
return getattr(self.writer(t, class_type, False), attr)
return getattr(self.writer(t), attr)

if 'add_fk' in kwargs:
classes = [_method(t)(add_fk=kwargs['add_fk']) for t in self.tables]
if len(kwargs) > 0:
classes = [_method(t)(**kwargs) for t in self.tables]
else:
classes = [_method(t)() for t in self.tables]

Expand All @@ -162,7 +180,19 @@ def write_custom_classes(self) -> str | None:
)

def write_base_classes(self) -> str:
"""Method to generate the base classes for the file."""
"""Method to generate the base & parent classes for the file."""
if self.add_null_parent_classes:
return (
self._class_writer_helper(
'Parent Classes',
comments=[
'This is a parent class with all fields as nullable. This is useful for refining your models with inheritance. See https://stackoverflow.com/a/65907609.' # noqa: E501
],
class_type=WriterClassType.PARENT,
)
+ '\n'
) + self._class_writer_helper('Base Classes', class_type=WriterClassType.BASE_WITH_PARENT)

return self._class_writer_helper('Base Classes')

def write_operational_classes(self) -> str | None:
Expand All @@ -174,8 +204,10 @@ def write_operational_classes(self) -> str | None:


class PydanticJSONAPIClassWriter(PydanticFastAPIClassWriter):
def __init__(self, table: TableInfo, nullify_base_schema_class: bool = False):
super().__init__(table, nullify_base_schema_class)
def __init__(
self, table: TableInfo, class_type: WriterClassType = WriterClassType.BASE, null_defaults: bool = False
):
super().__init__(table, class_type, null_defaults)

def write_foreign_columns(self, use_base: bool = True) -> str | None:
"""Method to generate foreign column definitions for the class."""
Expand Down Expand Up @@ -233,8 +265,14 @@ def write_columns(self, add_fk: bool = False) -> str:


class PydanticJSONAPIWriter(PydanticFastAPIWriter):
def __init__(self, tables: list[TableInfo], file_path: str):
super().__init__(tables, file_path, PydanticJSONAPIClassWriter)
def __init__(
self,
tables: list[TableInfo],
file_path: str,
writer: type[AbstractClassWriter] = PydanticJSONAPIClassWriter,
add_null_parent_classes: bool = False,
):
super().__init__(tables, file_path, writer, add_null_parent_classes)

def write_imports(self) -> str:
"""Method to generate the imports for the file."""
Expand Down
32 changes: 22 additions & 10 deletions supabase_pydantic/util/writers/sqlalchemy_writers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from supabase_pydantic.util.constants import RelationType
from supabase_pydantic.util.constants import RelationType, WriterClassType
from supabase_pydantic.util.dataclasses import ColumnInfo, SortedColumns, TableInfo
from supabase_pydantic.util.util import get_sqlalchemy_v2_type, to_pascal_case
from supabase_pydantic.util.writers.abstract_classes import AbstractClassWriter, AbstractFileWriter
Expand All @@ -10,8 +10,10 @@


class SqlAlchemyFastAPIClassWriter(AbstractClassWriter):
def __init__(self, table: TableInfo, nullify_base_schema_class: bool = False):
super().__init__(table, nullify_base_schema_class)
def __init__(
self, table: TableInfo, class_type: WriterClassType = WriterClassType.BASE, null_defaults: bool = False
):
super().__init__(table, class_type, null_defaults)
self._tname = to_pascal_case(self.table.name)
self.separated_columns: SortedColumns = self.table.sort_and_separate_columns(
separate_nullable=True, separate_primary_key=True
Expand All @@ -27,7 +29,7 @@ def write_metaclass(self, metaclasses: list[str] | None = None) -> str | None:

def write_docs(self) -> str:
"""Method to generate the docstrings for the class."""
qualifier = 'Nullable Base' if self.nullify_base_schema_class else 'Base'
qualifier = 'Nullable Base' if self._null_defaults else 'Base'
return f'\n\t"""{self._tname} {qualifier}."""\n\n\t__tablename__ = "{self.table.name}"\n\n'

def write_column(self, c: ColumnInfo) -> str:
Expand Down Expand Up @@ -101,9 +103,13 @@ def write_columns(self, add_fk: bool = False) -> str:

class SqlAlchemyFastAPIWriter(AbstractFileWriter):
def __init__(
self, tables: list[TableInfo], file_path: str, writer: type[AbstractClassWriter] = SqlAlchemyFastAPIClassWriter
self,
tables: list[TableInfo],
file_path: str,
writer: type[AbstractClassWriter] = SqlAlchemyFastAPIClassWriter,
add_null_parent_classes: bool = False,
):
super().__init__(tables, file_path, writer)
super().__init__(tables, file_path, writer, add_null_parent_classes)

def _dt_imports(
self, imports: set, default_import: tuple[Any, Any | None] = ('String,str', 'from sqlalchemy import String')
Expand Down Expand Up @@ -186,8 +192,10 @@ def write_operational_classes(self) -> str | None:


class SqlAlchemyJSONAPIClassWriter(SqlAlchemyFastAPIClassWriter):
def __init__(self, table: TableInfo, nullify_base_schema_class: bool = False):
super().__init__(table, nullify_base_schema_class)
def __init__(
self, table: TableInfo, class_type: WriterClassType = WriterClassType.BASE, null_defaults: bool = False
):
super().__init__(table, class_type, null_defaults)

def write_foreign_columns(self, use_base: bool = False) -> str | None:
"""Method to generate foreign column definitions for the class."""
Expand All @@ -211,9 +219,13 @@ def write_foreign_columns(self, use_base: bool = False) -> str | None:

class SqlAlchemyJSONAPIWriter(SqlAlchemyFastAPIWriter):
def __init__(
self, tables: list[TableInfo], file_path: str, writer: type[AbstractClassWriter] = SqlAlchemyJSONAPIClassWriter
self,
tables: list[TableInfo],
file_path: str,
writer: type[AbstractClassWriter] = SqlAlchemyJSONAPIClassWriter,
add_null_parent_classes: bool = False,
):
super().__init__(tables, file_path, writer)
super().__init__(tables, file_path, writer, add_null_parent_classes)

def write_imports(self) -> str:
"""Method to generate the imports for the file."""
Expand Down
Loading

0 comments on commit 6c6c673

Please sign in to comment.