Skip to content

Commit

Permalink
Merge pull request #89 from DanCardin/dc/role-env
Browse files Browse the repository at this point in the history
fix: Add role name coercion to postgres default grant to argument.
  • Loading branch information
DanCardin authored Sep 16, 2024
2 parents 652478e + 1d3606b commit bdf9223
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 32 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
# Changelog

## 0.15

### 0.15.0

- fix: Add role name coercion to postgres default grant `to` argument.
- feat: Add ability to supply environment deferred password value to postgres role.

## 0.14

### 0.14.0

- feat: Add basic support for triggers with arguments to Postgres.

## 0.13

### 0.13.0

- feat: Add support for MetaData.drop_all.
- feat: Add basic support for functions and procedures to MySQL.

## 0.12

### 0.12.0

- feat: Add basic support for triggers to MySQL.

## 0.11
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.14.0"
version = "0.15.0"
authors = ["Dan Cardin <ddcardin@gmail.com>"]

description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
Expand Down
9 changes: 8 additions & 1 deletion src/sqlalchemy_declarative_extensions/alembic/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ def _compare_roles(autogen_context, upgrade_ops, _):
@renderers.dispatch_for(DropRoleOp)
@renderers.dispatch_for(UpdateRoleOp)
def render_role(autogen_context: AutogenContext, op: CreateRoleOp):
return [f'op.execute("""{command}""")' for command in op.to_sql()]
is_dynamic = op.role.is_dynamic
if is_dynamic:
autogen_context.imports.add("import os")

return [
f'op.execute({"f" if is_dynamic else ""}"""{command}""")'
for command in op.to_sql(raw=False)
]


@Operations.implementation_for(CreateRoleOp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class DefaultGrant:

@classmethod
def on_tables_in_schema(
cls, *in_schemas: str | HasName, for_role: HasName | None = None
cls, *in_schemas: str | HasName, for_role: HasName | str | None = None
) -> DefaultGrant:
schemas = _map_schema_names(*in_schemas)
return cls(
Expand All @@ -118,7 +118,7 @@ def on_tables_in_schema(

@classmethod
def on_sequences_in_schema(
cls, *in_schemas: str | HasName, for_role: HasName | None = None
cls, *in_schemas: str | HasName, for_role: HasName | str | None = None
) -> DefaultGrant:
schemas = _map_schema_names(*in_schemas)
return cls(
Expand All @@ -129,7 +129,7 @@ def on_sequences_in_schema(

@classmethod
def on_types_in_schema(
cls, *in_schemas: str | HasName, for_role: HasName | None = None
cls, *in_schemas: str | HasName, for_role: HasName | str | None = None
) -> DefaultGrant:
schemas = _map_schema_names(*in_schemas)
return cls(
Expand All @@ -140,7 +140,7 @@ def on_types_in_schema(

@classmethod
def on_functions_in_schema(
cls, *in_schemas: str | HasName, for_role: HasName | None = None
cls, *in_schemas: str | HasName, for_role: HasName | str | None = None
) -> DefaultGrant:
schemas = _map_schema_names(*in_schemas)
return cls(
Expand All @@ -149,22 +149,22 @@ def on_functions_in_schema(
target_role=_coerce_name(for_role) if for_role is not None else None,
)

def for_role(self, role: str):
return replace(self, target_role=role)
def for_role(self, role: HasName | str):
return replace(self, target_role=_coerce_name(role))

def grant(
self,
grant: str | G | Grant,
*grants: str | G,
to,
to: HasName | str,
grant_option=False,
):
if not isinstance(grant, Grant):
grant = Grant(
grants=tuple(
_map_grant_names(self.grant_type.to_variants(), grant, *grants)
),
target_role=to,
target_role=_coerce_name(to),
grant_option=grant_option,
)
return DefaultGrantStatement(self, grant)
Expand Down
25 changes: 17 additions & 8 deletions src/sqlalchemy_declarative_extensions/dialects/postgresql/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Role(generic.Role):
connection_limit: int | None = None
valid_until: datetime | None = None

password: str | None = field(default=None, compare=False)
password: generic.Env | str | None = field(default=None, compare=False)

@classmethod
def from_pg_role(cls, r) -> Role:
Expand Down Expand Up @@ -81,15 +81,19 @@ def options(self):

yield f.name, value

@property
def is_dynamic(self) -> bool:
return isinstance(self.password, generic.Env)

def __repr__(self):
cls_name = self.__class__.__name__
options = ", ".join([f"{key}={value!r}" for key, value in self.options])
return f'{cls_name}("{self.name}", {options})'

def to_sql_create(self) -> list[str]:
def to_sql_create(self, raw: bool = True) -> list[str]:
segments = ["CREATE ROLE", self.name]

options = postgres_render_role_options(self)
options = postgres_render_role_options(self, raw=raw)
if options:
segments.append("WITH")
segments.extend(options)
Expand All @@ -102,7 +106,7 @@ def to_sql_create(self) -> list[str]:
command = " ".join(segments) + ";"
return [command]

def to_sql_update(self, to_role: Role) -> list[str]:
def to_sql_update(self, to_role: Role, raw: bool = True) -> list[str]:
role_name = to_role.name
diff = RoleDiff.diff(self, to_role)

Expand All @@ -111,7 +115,7 @@ def to_sql_update(self, to_role: Role) -> list[str]:
if self.use_role:
result.append(f"SET ROLE {self.use_role};")

diff_options = postgres_render_role_options(diff)
diff_options = postgres_render_role_options(diff, raw=raw)
if diff_options:
segments = ["ALTER ROLE", role_name, "WITH", *diff_options]
alter_role = " ".join(segments) + ";"
Expand All @@ -127,7 +131,7 @@ def to_sql_update(self, to_role: Role) -> list[str]:
result.append("RESET ROLE")
return result

def to_sql_drop(self) -> list[str]:
def to_sql_drop(self, raw: bool = True) -> list[str]:
return [f'DROP ROLE "{self.name}";']

def to_sql_use(self, undo: bool) -> list[str]:
Expand Down Expand Up @@ -231,7 +235,7 @@ def conditional_option(option, condition):
return option


def postgres_render_role_options(role: Role | RoleDiff) -> list[str]:
def postgres_render_role_options(role: Role | RoleDiff, raw: bool = False) -> list[str]:
segments = []

if role.superuser is not None:
Expand Down Expand Up @@ -267,7 +271,12 @@ def postgres_render_role_options(role: Role | RoleDiff) -> list[str]:
segments.append(segment)

if isinstance(role, Role) and role.password is not None:
segment = f"PASSWORD {role.password}"
password = (
role.password.resolve(raw=raw)
if isinstance(role.password, generic.Env)
else role.password
)
segment = f"PASSWORD '{password}'"
segments.append(segment)

if role.valid_until is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def kind(self) -> str:
)
return "USER" if is_user else "ROLE"

def to_sql_create(self) -> list[str]:
def to_sql_create(self, raw: bool = True) -> list[str]:
segments = [f"CREATE {self.kind}", self.name]

options = render_role_options(self)
Expand All @@ -143,7 +143,7 @@ def to_sql_create(self) -> list[str]:

return result

def to_sql_update(self, to_role: Role) -> list[str]:
def to_sql_update(self, to_role: Role, raw: bool = True) -> list[str]:
role_name = to_role.name
diff = RoleDiff.diff(self, to_role)

Expand Down
3 changes: 2 additions & 1 deletion src/sqlalchemy_declarative_extensions/role/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sqlalchemy_declarative_extensions.role.generic import Role
from sqlalchemy_declarative_extensions.role.generic import Env, Role

__all__ = [
"Env",
"Role",
]
14 changes: 7 additions & 7 deletions src/sqlalchemy_declarative_extensions/role/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def create_role(cls, operations, role_name: str, **options):
def reverse(self):
return DropRoleOp(self.role)

def to_sql(self) -> list[str]:
def to_sql(self, raw: bool = True) -> list[str]:
role_sql = UseRoleOp.to_sql_from_use_role_ops(self.use_role_ops)
return [*role_sql, *self.role.to_sql_create()]
return [*role_sql, *self.role.to_sql_create(raw=raw)]


@dataclass
Expand All @@ -58,9 +58,9 @@ def update_role(
def reverse(self):
return UpdateRoleOp(from_role=self.role, role=self.from_role)

def to_sql(self):
def to_sql(self, raw: bool = True):
role_sql = UseRoleOp.to_sql_from_use_role_ops(self.use_role_ops)
return [*role_sql, *self.from_role.to_sql_update(self.role)]
return [*role_sql, *self.from_role.to_sql_update(self.role, raw=raw)]


@dataclass
Expand All @@ -76,9 +76,9 @@ def drop_role(cls, operations, role_name: str):
def reverse(self):
return CreateRoleOp(self.role)

def to_sql(self) -> list[str]:
def to_sql(self, raw: bool = True) -> list[str]:
role_sql = UseRoleOp.to_sql_from_use_role_ops(self.use_role_ops)
return [*role_sql, *self.role.to_sql_drop()]
return [*role_sql, *self.role.to_sql_drop(raw=raw)]


@dataclass
Expand All @@ -102,7 +102,7 @@ def to_sql_from_use_role_ops(cls, use_role_ops: list[UseRoleOp] | None):
def reverse(self):
return self

def to_sql(self) -> list[str]:
def to_sql(self, raw: bool = True) -> list[str]:
return self.role.to_sql_use(undo=self.undo)


Expand Down
2 changes: 1 addition & 1 deletion src/sqlalchemy_declarative_extensions/role/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def receive_after_create(metadata: MetaData, connection: Connection, **_):
if not match_name(op.role.name, role_filter):
continue

statements = op.to_sql()
statements = op.to_sql(raw=True)
if isinstance(statements, list):
for statement in statements:
connection.execute(text(statement))
Expand Down
34 changes: 31 additions & 3 deletions src/sqlalchemy_declarative_extensions/role/generic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from __future__ import annotations

import os
from dataclasses import dataclass, replace

from sqlalchemy_declarative_extensions.context import context

__all__ = [
"Role",
"Env",
]


@dataclass(order=True)
class Role:
Expand Down Expand Up @@ -56,6 +62,10 @@ def from_unknown_role(cls, r: Role) -> Role:
def has_option(self):
return False

@property
def is_dynamic(self) -> bool:
return False

@property
def options(self):
yield from []
Expand All @@ -67,19 +77,19 @@ def normalize(self):
use_role=role_name(self.use_role) if self.use_role else None,
)

def to_sql_create(self) -> list[str]:
def to_sql_create(self, raw: bool = True) -> list[str]:
statement = f'CREATE ROLE "{self.name}"'
if self.in_roles is not None:
in_roles = ", ".join(role_names(self.in_roles))
statement += f"IN ROLE {in_roles}"
return [statement + ";"]

def to_sql_update(self, to_role) -> list[str]:
def to_sql_update(self, to_role, raw: bool = True) -> list[str]:
raise NotImplementedError(
"When using the generic role, there should never exist any cause to update a role."
)

def to_sql_drop(self) -> list[str]:
def to_sql_drop(self, raw: bool = True) -> list[str]:
return [f'DROP ROLE "{self.name}";']

def to_sql_use(self, undo: bool) -> list[str]:
Expand All @@ -92,6 +102,24 @@ def __exit__(self, *_):
context.exit_role()


@dataclass
class Env:
"""Provide a way to supply dynamic password variables through the environment at migration time."""

name: str
default: str | None = None

def resolve(self, raw: bool = False):
if raw:
if self.default is not None:
return os.environ.get(self.name, self.default)
return os.environ[self.name]

if self.default is not None:
return f'{{os.environ.get("{self.name}", "{self.default}")}}'
return f'{{os.environ["{self.name}"]}}'


def by_name(r: Role | str) -> str:
if isinstance(r, Role):
return r.name
Expand Down

0 comments on commit bdf9223

Please sign in to comment.