Skip to content

Commit

Permalink
feat: update version logic to dts and latest schema (#39)
Browse files Browse the repository at this point in the history
* feat: add ability to inherit from all-null parent classes in pydantic fastapi models
* feat: change versioning to datetime
  • Loading branch information
kmbhm1 authored Aug 5, 2024
1 parent d7535e0 commit 06193cb
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 24 deletions.
17 changes: 10 additions & 7 deletions supabase_pydantic/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,14 @@ def clean(ctx: Any, directory: str) -> None:
help='The directory to save files',
required=False,
)
@click.option('--overwrite/--no-overwrite', default=True, help='Overwrite existing files. Defaults to overwrite.')
@click.option(
'--no-overwrite',
'overwrite',
is_flag=True,
show_default=True,
default=False,
help='Overwrite existing files. Defaults to overwrite.',
)
@click.option(
'--null-parent-classes',
is_flag=True,
Expand All @@ -204,10 +211,6 @@ def gen(
) -> None:
"""Generate models from a PostgreSQL database."""
# pp.pprint(locals())
# if dburl is None and project_id is None and not local and not linked:
# print('Please provide a connection source. Exiting...')
# return

# Load environment variables from .env file & check if they are set correctly
if not local and db_url is None:
print('Please provide a connection source. Exiting...')
Expand Down Expand Up @@ -256,10 +259,10 @@ def gen(
factory = FileWriterFactory()
for job, c in jobs.items(): # c = config
print(f'Generating {job} models...')
p = factory.get_file_writer(
p, vf = factory.get_file_writer(
tables, c.fpath(), c.file_type, c.framework_type, add_null_parent_classes=null_parent_classes
).save(overwrite)
paths.append(p)
paths += [p, vf] if vf is not None else [p]
print(f'{job} models generated successfully: {p}')

# Format the generated files
Expand Down
16 changes: 12 additions & 4 deletions supabase_pydantic/util/writers/abstract_classes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path

Expand Down Expand Up @@ -109,15 +110,22 @@ def write(self) -> str:
# filter None and join parts
return self.jstr.join(p for p in parts if p is not None) + '\n'

def save(self, overwrite: bool = False) -> str:
def save(self, overwrite: bool = False) -> tuple[str, str | None]:
"""Method to save the file."""
fp = Path(self.file_path)
base, ext, directory = fp.stem, fp.suffix, str(fp.parent)
p = generate_unique_filename(base, ext, directory) if not overwrite and fp.exists() else self.file_path
with open(p, 'w') as f:
latest_file = os.path.join(directory, f'{base}_latest{ext}')
with open(latest_file, 'w') as f:
f.write(self.write())

return p
if overwrite:
versioned_file = generate_unique_filename(base, ext, directory)
with open(versioned_file, 'w') as f:
f.write(self.write())

return latest_file, versioned_file

return latest_file, None

def join(self, strings: list[str]) -> str:
"""Method to join strings."""
Expand Down
21 changes: 12 additions & 9 deletions supabase_pydantic/util/writers/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone

from supabase_pydantic.util.constants import BASE_CLASS_POSTFIX, WriterClassType
from supabase_pydantic.util.util import chunk_text
Expand All @@ -23,15 +24,17 @@ def generate_unique_filename(base_name: str, extension: str, directory: str = '.
"""
extension = extension.lstrip('.')
file_name = f'{base_name}.{extension}'
file_path = os.path.join(directory, file_name)
i = 1
while os.path.exists(file_path):
file_name = f'{base_name}_{i}.{extension}'
file_path = os.path.join(directory, file_name)
i += 1

return file_path
dt_str = datetime.now(tz=timezone.utc).strftime('%Y%m%d%H%M%S%f')
file_name = f'{base_name}_{dt_str}.{extension}'

# file_path = os.path.join(directory, file_name)
# i = 1
# while os.path.exists(file_path):
# file_name = f'{base_name}_{i}.{extension}'
# file_path = os.path.join(directory, file_name)
# i += 1

return os.path.join(directory, file_name)


def get_section_comment(comment_title: str, notes: list[str] | None = None) -> str:
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/writers/test_abstract_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,14 @@ def test_save_method():
result = writer.save(overwrite=False)

# Assert the generate_unique_filename was called correctly
mock_unique_filename.assert_called_once_with('test_file', '.py', 'directory')
# mock_unique_filename.assert_called_once_with('test_file', '.py', 'directory')
mock_unique_filename.assert_not_called()

# Assert the file was opened with the correct filename and mode
mock_open.assert_called_once_with('test_file_unique.py', 'w')
mock_open.assert_called_once_with('directory/test_file_latest.py', 'w')

# Assert the correct path is returned
assert result == 'test_file_unique.py'
assert result == ('directory/test_file_latest.py', None)


def test_abstract_file_writer_type_error_on_implementation():
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/writers/test_writer_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import datetime
import re
import pytest
from unittest.mock import patch
from supabase_pydantic.util.constants import WriterClassType
Expand Down Expand Up @@ -28,7 +30,16 @@ def test_generate_unique_filename():

with patch('os.path.exists', side_effect=lambda x: x.endswith('test.py') or x.endswith('test_1.py')) as mock_exists:
unique_filename = generate_unique_filename(base_name, extension, directory)
assert unique_filename == '/fake/directory/test_2.py'
match = re.search(r'test_(\d{20})\.py$', unique_filename)
assert match is not None, f'Filename does not match the expected pattern: {unique_filename}'

# Validate the datetime format
# Assuming the format is YYYYMMDDHHMMSSffffffffff (year, month, day, hour, minute, second, microsecond)
datetime_str = match.group(1)
try:
datetime.datetime.strptime(datetime_str, '%Y%m%d%H%M%S%f')
except ValueError:
assert False, 'Datetime format is incorrect'


def test_get_section_comment_without_notes():
Expand Down

0 comments on commit 06193cb

Please sign in to comment.