Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Squash Migrations #346

Merged
merged 13 commits into from
Aug 7, 2024
2 changes: 2 additions & 0 deletions nodestream/cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .scaffold import Scaffold
from .show import Show
from .show_migrations import ShowMigrations
from .squash_migrations import SquashMigration

__all__ = (
"AuditCommand",
Expand All @@ -24,4 +25,5 @@
"Scaffold",
"ShowMigrations",
"Show",
"SquashMigration",
)
31 changes: 31 additions & 0 deletions nodestream/cli/commands/squash_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from cleo.helpers import option

from ..operations import GenerateSquashedMigration
from .nodestream_command import NodestreamCommand
from .shared_options import PROJECT_FILE_OPTION


class SquashMigration(NodestreamCommand):
    """CLI command that squashes a range of migrations into one migration.

    The command reads the ``from`` and ``to`` options, loads the project's
    migrations, and delegates the actual work to the
    ``GenerateSquashedMigration`` operation.
    """

    name = "migrations squash"
    # Describe what *this* command does; the previous text was the
    # description of the "make migration" command.
    description = "Squash a range of migrations into a single migration."
    options = [
        PROJECT_FILE_OPTION,
        option(
            "from",
            description="The name of the migration to squash from.",
            value_required=True,
            flag=False,
        ),
        option("to", description="The name of the migration to squash to.", flag=False),
    ]

    async def handle_async(self):
        """Build the squash operation from the CLI options and run it."""
        from_migration_name = self.option("from")
        # ``to`` is optional; the operation treats a falsy value as "no end
        # migration given".
        to_migration_name = self.option("to")
        migrations = self.get_migrations()
        operation = GenerateSquashedMigration(
            migrations,
            from_migration_name,
            to_migration_name,
        )
        await self.run_operation(operation)
2 changes: 2 additions & 0 deletions nodestream/cli/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .execute_migration import ExecuteMigrations
from .generate_migration import GenerateMigration
from .generate_pipeline_scaffold import GeneratePipelineScaffold
from .generate_squashed_migration import GenerateSquashedMigration
from .initialize_logger import InitializeLogger
from .initialize_project import InitializeProject
from .operation import Operation
Expand All @@ -20,6 +21,7 @@
"ExecuteMigrations",
"GenerateMigration",
"GeneratePipelineScaffold",
"GenerateSquashedMigration",
"InitializeLogger",
"InitializeProject",
"Operation",
Expand Down
33 changes: 33 additions & 0 deletions nodestream/cli/operations/generate_squashed_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from ...schema.migrations import ProjectMigrations
from .operation import NodestreamCommand, Operation


class GenerateSquashedMigration(Operation):
    """Operation that creates a squashed migration and reports on it.

    Args:
        migrations: The project's migrations; its ``graph`` is used to look
            migrations up by name and ``create_squash_between`` to squash.
        from_migration_name: Name of the migration the squash starts from.
        to_migration_name: Name of the migration the squash ends at. A falsy
            value means no end migration is looked up.
    """

    def __init__(
        self,
        migrations: ProjectMigrations,
        from_migration_name: str,
        to_migration_name: str,
    ) -> None:
        self.migrations = migrations
        self.from_migration_name = from_migration_name
        self.to_migration_name = to_migration_name

    async def perform(self, command: NodestreamCommand):
        """Create the squashed migration and print a summary to ``command``."""
        graph = self.migrations.graph
        squash_start = graph.get_migration(self.from_migration_name)

        # Only resolve the end migration when a name was actually supplied.
        squash_end = None
        if self.to_migration_name:
            squash_end = graph.get_migration(self.to_migration_name)

        migration, path = self.migrations.create_squash_between(
            squash_start, squash_end
        )

        command.line(f"Generated squashed migration {migration.name}.")
        command.line(
            f"The migration contains {len(migration.operations)} schema changes."
        )
        # One bullet per schema change contained in the squashed migration.
        for operation in migration.operations:
            command.line(f"  - {operation.describe()}")
        command.line(f"Migration written to {path}")
        command.line("Run `nodestream migrations run` to apply the migration.")
133 changes: 121 additions & 12 deletions nodestream/schema/migrations/migrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, Optional

from ...file_io import LoadsFromYamlFile, SavesToYamlFile
from .operations import Operation
Expand All @@ -16,12 +16,14 @@ class Migration(LoadsFromYamlFile, SavesToYamlFile):
name: str
operations: List[Operation]
dependencies: List[str]
replaces: List[str] = field(default_factory=list)

def to_file_data(self):
    """Return this migration as a plain dictionary for writing to a file.

    Returns:
        A dict with the keys ``name``, ``operations`` (each operation's own
        ``to_file_data()`` result), ``dependencies`` and ``replaces``.
    """
    serialized_operations = [step.to_file_data() for step in self.operations]
    return {
        "name": self.name,
        "operations": serialized_operations,
        "dependencies": self.dependencies,
        "replaces": self.replaces,
    }

@classmethod
Expand All @@ -33,17 +35,19 @@ def from_file_data(cls, file_data):
for operation_file_data in file_data["operations"]
],
dependencies=file_data["dependencies"],
replaces=file_data.get("replaces", []),
)

@classmethod
def describe_yaml_schema(cls):
    """Return the ``schema.Schema`` describing a migration file.

    ``name``, ``operations`` and ``dependencies`` are required keys;
    ``replaces`` is an optional list of strings.
    """
    # Function-local import. Note that ``Optional`` here is the ``schema``
    # package's optional-key marker, not ``typing.Optional``.
    from schema import Optional, Schema

    migration_shape = {
        "name": str,
        "operations": [Operation.describe_yaml_schema()],
        "dependencies": [str],
        Optional("replaces"): [str],
    }
    return Schema(migration_shape)

Expand Down Expand Up @@ -74,6 +78,53 @@ def write_to_file_with_default_name(self, directory: Path) -> Path:
self.write_to_file(path)
return path

def is_squashed_migration(self) -> bool:
    """Report whether this migration replaces other migrations.

    Returns:
        True when ``replaces`` has at least one entry, False otherwise.
    """
    return bool(self.replaces)

@classmethod
def squash(
    cls,
    new_name: str,
    migrations: Iterable["Migration"],
    optimize_operations: bool = True,
) -> "Migration":
    """Build one migration that stands in for the given migrations.

    Args:
        new_name: The name of the new migration.
        migrations: The migrations to squash, in the order their
            operations should appear. Any iterable is accepted, including
            one-shot iterables such as generators.
        optimize_operations: Whether to pass the combined operations
            through ``Operation.optimize`` before building the migration.

    Returns:
        A new migration whose operations are the concatenated operations
        of ``migrations``, whose ``replaces`` lists the squashed
        migrations' names, and whose ``dependencies`` are the squashed
        migrations' dependencies that are not themselves being squashed.
    """
    # The input is walked several times below; materialize it once so a
    # generator is not silently exhausted after the first pass.
    migrations = list(migrations)

    # ``dict.fromkeys`` de-duplicates while keeping first-seen order, so the
    # resulting lists (and the file written from them) are stable across runs.
    replaced_names = list(dict.fromkeys(m.name for m in migrations))
    squashed_name_set = set(replaced_names)

    effective_operations = [o for m in migrations for o in m.operations]
    dependencies = list(
        dict.fromkeys(
            d
            for m in migrations
            for d in m.dependencies
            if d not in squashed_name_set
        )
    )
    if optimize_operations:
        effective_operations = Operation.optimize(effective_operations)

    return cls(
        name=new_name,
        operations=effective_operations,
        replaces=replaced_names,
        dependencies=dependencies,
    )


@dataclass(frozen=True, slots=True)
class MigrationGraph:
Expand All @@ -98,20 +149,56 @@ def get_migration(self, name: str) -> Migration:
"""
return self.migrations_by_name[name]

def get_ordered_migration_plan(
    self, completed_migrations: List["Migration"]
) -> List["Migration"]:
    """Return the migrations that still need to run, in execution order.

    Args:
        completed_migrations: Migrations that have already been applied.

    Returns:
        The not-yet-completed migrations, in the order given by
        ``topological_order()``, with squashed and replaced migrations
        filtered so that only one of the two forms is planned.
    """
    done_names = {m.name for m in completed_migrations}
    # Maps the name of each replaced migration to the migration replacing it.
    squashed_by = {
        replaced_name: squasher
        for squasher in self.all_migrations()
        for replaced_name in squasher.replaces
    }

    def any_replaced_completed(squasher) -> bool:
        # True when at least one migration replaced by ``squasher`` is done.
        return any(name in done_names for name in squasher.replaces)

    plan = []
    for candidate in self.topological_order():
        if candidate.name in done_names:
            continue

        # A migration that has been replaced by a squashed migration is only
        # planned when the squash has been (at least partially) applied the
        # long way, i.e. some migration it replaces is already completed.
        # Otherwise the squashed migration is planned in its place.
        squasher = squashed_by.get(candidate.name)
        if squasher is not None and not any_replaced_completed(squasher):
            continue

        # Conversely, a squashed migration is only planned when none of the
        # migrations it replaces have been completed.
        if candidate.is_squashed_migration() and any_replaced_completed(candidate):
            continue

        # At this point the candidate is one of:
        #   1. a regular migration that has not been completed,
        #   2. a squashed migration that should run, or
        #   3. a replaced migration whose squash was partially applied.
        # All three belong in the plan.
        plan.append(candidate)

    return plan

def _iterative_dfs_traversal(self, start_node: Migration) -> List[Migration]:
def _iterative_dfs_traversal(self, *start_node: Migration) -> List[Migration]:
visited_order = []
visited_set = set()
stack = [(start_node, False)]
stack = [(n, False) for n in start_node]

while stack:
node, processed = stack.pop()
Expand All @@ -129,8 +216,30 @@ def _iterative_dfs_traversal(self, start_node: Migration) -> List[Migration]:

return visited_order

def plan_to_execute_migration(self, migration: "Migration") -> List["Migration"]:
    """Return the DFS traversal helper's visit order started at ``migration``."""
    execution_plan = self._iterative_dfs_traversal(migration)
    return execution_plan
def squash_between(
    self,
    name: str,
    from_migration: "Migration",
    to_migration: Optional["Migration"] = None,
) -> "Migration":
    """Squash all migrations between two migrations, both ends included.

    The range is a slice of the order returned by ``topological_order()``.

    Args:
        name: The name of the new squashed migration.
        from_migration: The migration to start squashing from.
        to_migration: The migration to stop squashing at. When ``None``,
            the range runs to the end of the ordering.

    Returns:
        The new squashed migration.

    Raises:
        ValueError: If either migration is not part of the ordering, or if
            ``to_migration`` comes before ``from_migration`` (which would
            otherwise produce an empty squashed migration).
    """
    ordered = self.topological_order()
    from_index = ordered.index(from_migration)
    if to_migration is None:
        to_index = len(ordered) - 1
    else:
        to_index = ordered.index(to_migration)

    # A reversed range would slice to nothing and yield a migration that
    # replaces nothing; fail loudly instead.
    if to_index < from_index:
        raise ValueError(
            f"Cannot squash from '{from_migration.name}' to "
            f"'{to_migration.name}': the 'to' migration comes before the "
            "'from' migration."
        )

    migrations_to_squash = ordered[from_index : to_index + 1]
    return Migration.squash(name, migrations_to_squash)

def topological_order(self):
    """Return the DFS traversal helper's visit order, seeded with every leaf migration."""
    leaf_migrations = self.get_leaf_migrations()
    return self._iterative_dfs_traversal(*leaf_migrations)

def all_migrations(self) -> Iterable[Migration]:
"""Iterate over all migrations in the graph."""
Expand Down
Loading