Skip to content

Commit

Permalink
feat: add db-url connector (#37)
Browse files Browse the repository at this point in the history
* feat: add db-url connector
  • Loading branch information
kmbhm1 authored Aug 4, 2024
1 parent d55991d commit 42fe787
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 81 deletions.
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ $ echo "DB_PORT=<your_db_port>" >> .env # add your postgres db port

## Usage

Generate Pydantic models for FastAPI:
Generate Pydantic models for FastAPI using a local supabase connection:

```bash
$ sb-pydantic gen --type pydantic --framework fastapi --local
Expand All @@ -35,6 +35,19 @@ FastAPI Pydantic models generated successfully: /path/to/your/project/entities/f
File formatted successfully: /path/to/your/project/entities/fastapi/schemas.py
```

Or generate with a url:

```bash
$ sb-pydantic gen --type pydantic --framework fastapi --db-url postgresql://postgres:postgres@127.0.0.1:54322/postgres

Checking local database connection.postgresql://postgres:postgres@127.0.0.1:54322/postgres
Connecting to database: postgres on host: 127.0.0.1 with user: postgres and port: 54322
PostGres connection is open.
Generating FastAPI Pydantic models...
FastAPI Pydantic models generated successfully: /path/to/your/project/entities/fastapi/schemas.py
File formatted successfully: /path/to/your/project/entities/fastapi/schemas.py
```

For some users, integrating a Makefile command may be more convenient:

```bash
Expand Down
1 change: 1 addition & 0 deletions docs/other/references.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

- [Inspiration](https://github.com/yngvem/python-project-structure) for the project structure.
- Github Actions [inspiration](https://endjin.com/blog/2023/02/how-to-implement-continuous-deployment-of-python-packages-with-github-actions) for CI/CD.
- List of [ORMs](https://www.fullstackpython.com/object-relational-mappers-orms.html) & [aerich](https://github.com/tortoise/aerich)

## Testing

Expand Down
4 changes: 3 additions & 1 deletion docs/other/to-do.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
- [x] Integrate mkdocs, mkdocs-material, mkdocstrings, mkdocstrings-python
- [x] Explore security scanners in pipelines
- [ ] Explore SDK integrations, rather than CLI use
- [ ] Test with other conn methods (e.g., supabase secret key)
- [x] Test with other conn methods (e.g., supabase secret key)
- [x] Separate nullable and non-nullable columns in models in a better way
- [ ] Acquire [test dbs](https://github.com/morenoh149/postgresDBSamples) for integration tests
- [x] Finish adding tests for writers and marshalers
Expand All @@ -32,3 +32,5 @@
- [ ] Add to conda
- [ ] Marketing: Submit repo for reviews with Pydantic docs, Supabase docs, Reddit boards, stackoverflow etc.
- [ ] Add fake generator for inserts and seed data
- [ ] Add supabase_secret key connection method
- [ ] Add mysql and other conns ...
69 changes: 42 additions & 27 deletions supabase_pydantic/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pprint
import re
from typing import Any

import click
Expand All @@ -8,11 +9,13 @@
from dotenv import find_dotenv, load_dotenv

from supabase_pydantic.util import (
POSTGRES_SQL_CONN_REGEX,
AppConfig,
DatabaseConnectionType,
FileWriterFactory,
ToolConfig,
clean_directories,
construct_table_info_from_postgres,
construct_tables,
format_with_ruff,
get_standard_jobs,
get_working_directories,
Expand Down Expand Up @@ -126,11 +129,7 @@ def clean(ctx: Any, directory: str) -> None:
# is_flag=True,
# help='Use linked database connection.',
# )
# @connect_sources.option(
# '--dburl',
# type=str,
# help='Use database URL for connection.',
# )

# @connect_sources.option(
# '--project-id',
# type=str,
Expand Down Expand Up @@ -164,6 +163,11 @@ def clean(ctx: Any, directory: str) -> None:
is_flag=True,
help='Use local database connection.',
)
@connect_sources.option(
'--db-url',
type=str,
help='Use database URL for connection.',
)
@click.option(
'-d',
'--dir',
Expand All @@ -182,7 +186,7 @@ def gen(
overwrite: bool,
local: bool = False,
# linked: bool = False,
# dburl: str | None = None,
db_url: str | None = None,
# project_id: str | None = None,
) -> None:
"""Generate models from a PostgreSQL database."""
Expand All @@ -192,33 +196,44 @@ def gen(
# return

# Load environment variables from .env file & check if they are set correctly
if not local:
print('Only local connection is supported at the moment. Exiting...')
if not local and db_url is None:
print('Please provide a connection source. Exiting...')
return

load_dotenv(find_dotenv())
env_vars: dict[str, str | None] = {
'DB_NAME': os.environ.get('DB_NAME', None),
'DB_USER': os.environ.get('DB_USER', None),
'DB_PASS': os.environ.get('DB_PASS', None),
'DB_HOST': os.environ.get('DB_HOST', None),
'DB_PORT': os.environ.get('DB_PORT', None),
}
if any([v is None for v in env_vars.values()]) and local:
print(f'Critical environment variables not set: {", ".join([k for k, v in env_vars.items() if v is None])}.')
print('Using default local values...')
env_vars = local_default_env_configuration()

# Check if environment variables are set correctly
assert check_readiness(env_vars)
conn_type: DatabaseConnectionType = DatabaseConnectionType.LOCAL
env_vars: dict[str, str | None] = dict()
if db_url is not None:
print('Checking local database connection.' + db_url)
if re.match(POSTGRES_SQL_CONN_REGEX, db_url) is None:
print(f'Invalid database URL: "{db_url}". Exiting.')
return
conn_type = DatabaseConnectionType.DB_URL
env_vars['DB_URL'] = db_url
else:
load_dotenv(find_dotenv())
env_vars.update(
**{
'DB_NAME': os.environ.get('DB_NAME', None),
'DB_USER': os.environ.get('DB_USER', None),
'DB_PASS': os.environ.get('DB_PASS', None),
'DB_HOST': os.environ.get('DB_HOST', None),
'DB_PORT': os.environ.get('DB_PORT', None),
}
)
if any([v is None for v in env_vars.values()]) and local:
print(
f'Critical environment variables not set: {", ".join([k for k, v in env_vars.items() if v is None])}.'
)
print('Using default local values...')
env_vars = local_default_env_configuration()
# Check if environment variables are set correctly
assert check_readiness(env_vars)

# Get the directories for the generated files
dirs = get_working_directories(default_directory, frameworks, auto_create=True)

# Get the database schema details
tables = construct_table_info_from_postgres(
env_vars['DB_NAME'], env_vars['DB_USER'], env_vars['DB_PASS'], env_vars['DB_HOST'], env_vars['DB_PORT']
)
tables = construct_tables(conn_type, **env_vars)

# Configure the writer jobs
jobs = {k: v for k, v in get_standard_jobs(models, frameworks, dirs).items() if v.enabled}
Expand Down
8 changes: 6 additions & 2 deletions supabase_pydantic/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
GET_ALL_PUBLIC_TABLES_AND_COLUMNS,
GET_CONSTRAINTS,
GET_TABLE_COLUMN_DETAILS,
POSTGRES_SQL_CONN_REGEX,
PYDANTIC_TYPE_MAP,
AppConfig,
DatabaseConnectionType,
FrameWorkType,
ModelGenerationType,
OrmType,
Expand All @@ -14,7 +16,7 @@
from .dataclasses import AsDictParent, ColumnInfo, ForeignKeyInfo, TableInfo
from .db import (
check_connection,
construct_table_info_from_postgres,
construct_tables,
create_connection,
query_database,
)
Expand All @@ -37,12 +39,14 @@
'AsDictParent',
'ColumnInfo',
'CustomJsonEncoder',
'DatabaseConnectionType',
'FileWriterFactory',
'ForeignKeyInfo',
'FrameWorkType',
'GET_ALL_PUBLIC_TABLES_AND_COLUMNS',
'GET_CONSTRAINTS',
'GET_TABLE_COLUMN_DETAILS',
'POSTGRES_SQL_CONN_REGEX',
'Model',
'ModelGenerationType',
'OrmType',
Expand All @@ -54,7 +58,7 @@
'adapt_type_map',
'check_connection',
'clean_directories',
'construct_table_info_from_postgres',
'construct_tables',
'create_connection',
'create_directories_if_not_exist',
'format_with_ruff',
Expand Down
14 changes: 14 additions & 0 deletions supabase_pydantic/util/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from typing import TypedDict


class DatabaseConnectionType(Enum):
"""Enum for database connection types."""

LOCAL = 'local'
DB_URL = 'db_url'


class AppConfig(TypedDict, total=False):
default_directory: str
overwrite_existing_files: bool
Expand Down Expand Up @@ -313,3 +320,10 @@ class ModelGenerationType(str, Enum):
ORDER BY
conrelid::regclass::text, contype DESC;
"""


# Regex

POSTGRES_SQL_CONN_REGEX = (
r'(postgresql|postgres)://([^:@\s]*(?::[^@\s]*)?@)?(?P<server>[^/\?\s:]+)(:\d+)?(/[^?\s]*)?(\?[^\s]*)?$'
)
121 changes: 86 additions & 35 deletions supabase_pydantic/util/db.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,59 @@
from typing import Any
from typing import Any, Literal
from urllib.parse import urlparse

import psycopg2

from supabase_pydantic.util.constants import (
GET_ALL_PUBLIC_TABLES_AND_COLUMNS,
GET_CONSTRAINTS,
GET_TABLE_COLUMN_DETAILS,
DatabaseConnectionType,
)
from supabase_pydantic.util.dataclasses import TableInfo
from supabase_pydantic.util.exceptions import ConnectionError
from supabase_pydantic.util.marshalers import construct_table_info


def create_connection(dbname: str, user: str, password: str, host: str, port: str) -> Any:
def query_database(conn: Any, query: str) -> Any:
"""Query the database."""
cur = conn.cursor()

try:
cur.execute(query)
result = cur.fetchall()
return result
finally:
cur.close()


def create_connection(dbname: str, username: str, password: str, host: str, port: str) -> Any:
"""Create a connection to the database."""
try:
conn = psycopg2.connect(dbname=dbname, user=username, password=password, host=host, port=port)
return conn
except psycopg2.OperationalError as e:
raise ConnectionError(f'Error connecting to database: {e}')


def create_connection_from_db_url(db_url: str) -> Any:
"""Create a connection to the database."""
conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
return conn
result = urlparse(db_url)
username = result.username
password = result.password
database = result.path[1:]
host = result.hostname
if result.port is None:
raise ConnectionError(f'Invalid database URL port: {db_url}')
port = str(result.port)

assert username is not None, f'Invalid database URL user: {db_url}'
assert password is not None, f'Invalid database URL pass: {db_url}'
assert database is not None, f'Invalid database URL dbname: {db_url}'
assert host is not None, f'Invalid database URL host: {db_url}'

print(f'Connecting to database: {database} on host: {host} with user: {username} and port: {port}')

return create_connection(database, username, password, host, port)


def check_connection(conn: Any) -> bool:
Expand All @@ -26,41 +66,52 @@ def check_connection(conn: Any) -> bool:
return True


def query_database(conn: Any, query: str) -> Any:
"""Query the database."""
cur = conn.cursor()
class DBConnection:
def __init__(self, conn_type: DatabaseConnectionType, **kwargs: Any) -> None:
self.conn_type = conn_type
self.kwargs = kwargs
self.conn = self.create_connection()

try:
cur.execute(query)
result = cur.fetchall()
return result
finally:
cur.close()
def create_connection(self) -> Any:
"""Get the connection to the database."""
if self.conn_type == DatabaseConnectionType.DB_URL:
return create_connection_from_db_url(self.kwargs['DB_URL'])
elif self.conn_type == DatabaseConnectionType.LOCAL:
try:
return create_connection(
self.kwargs['DB_NAME'],
self.kwargs['DB_USER'],
self.kwargs['DB_PASS'],
self.kwargs['DB_HOST'],
self.kwargs['DB_PORT'],
)
except KeyError:
raise ValueError('Invalid connection parameters.')
else:
raise ValueError('Invalid connection type.')

def __enter__(self) -> Any:
return self.conn

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
self.conn.close()
return False


def construct_table_info_from_postgres(
db_name: str | None, user: str | None, password: str | None, host: str | None, port: str | None
) -> Any:
def construct_tables(conn_type: DatabaseConnectionType, **kwargs: Any) -> list[TableInfo]:
"""Construct table information from database."""
# Get Table & Column details from the database
conn: Any = None
try:
# Create a connection to the database & check if connection is successful
assert (
db_name is not None and user is not None and password is not None and host is not None and port is not None
), 'Environment variables not set correctly.'
conn = create_connection(db_name, user, password, host, port)
assert kwargs, 'Invalid or empty connection parameters.'

# Create a connection to the database & check if connection is successful
with DBConnection(conn_type, **kwargs) as conn:
assert check_connection(conn)

# Fetch table column details & foreign key details
column_details = query_database(conn, GET_ALL_PUBLIC_TABLES_AND_COLUMNS)
fk_details = query_database(conn, GET_TABLE_COLUMN_DETAILS)
constraints = query_database(conn, GET_CONSTRAINTS)
try:
# Fetch table column details & foreign key details
column_details = query_database(conn, GET_ALL_PUBLIC_TABLES_AND_COLUMNS)
fk_details = query_database(conn, GET_TABLE_COLUMN_DETAILS)
constraints = query_database(conn, GET_CONSTRAINTS)

return construct_table_info(column_details, fk_details, constraints)
except Exception as e:
raise e
finally:
if conn:
conn.close()
check_connection(conn)
return construct_table_info(column_details, fk_details, constraints)
except Exception as e:
raise e
4 changes: 4 additions & 0 deletions supabase_pydantic/util/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class ConnectionError(Exception):
"""Raised when a connection to the Supabase API cannot be established."""

pass
Loading

0 comments on commit 42fe787

Please sign in to comment.