Skip to content

Commit

Permalink
Merge pull request #346 from nodestream-proj/feature/squash-migrations
Browse files Browse the repository at this point in the history
Squash Migrations
  • Loading branch information
zprobst authored Aug 7, 2024
2 parents dc20606 + df0ad7c commit db2592d
Show file tree
Hide file tree
Showing 13 changed files with 918 additions and 17 deletions.
2 changes: 2 additions & 0 deletions nodestream/cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .scaffold import Scaffold
from .show import Show
from .show_migrations import ShowMigrations
from .squash_migrations import SquashMigration

__all__ = (
"AuditCommand",
Expand All @@ -24,4 +25,5 @@
"Scaffold",
"ShowMigrations",
"Show",
"SquashMigration",
)
31 changes: 31 additions & 0 deletions nodestream/cli/commands/squash_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from cleo.helpers import option

from ..operations import GenerateSquashedMigration
from .nodestream_command import NodestreamCommand
from .shared_options import PROJECT_FILE_OPTION


class SquashMigration(NodestreamCommand):
    """CLI command that squashes a range of migrations into a single one.

    Reads the ``from`` / ``to`` options, builds a
    ``GenerateSquashedMigration`` operation for the project's migrations and
    runs it.
    """

    name = "migrations squash"
    description = "Squash a range of existing migrations into a single migration."
    options = [
        PROJECT_FILE_OPTION,
        option(
            "from",
            description="The name of the migration to squash from.",
            value_required=True,
            flag=False,
        ),
        option("to", description="The name of the migration to squash to.", flag=False),
    ]

    async def handle_async(self):
        """Run the squash operation for the requested migration range.

        Returns:
            1 when the ``from`` option is missing; otherwise nothing.
        """
        from_migration_name = self.option("from")
        to_migration_name = self.option("to")

        # `value_required=True` only demands a value when `--from` is given;
        # it does not make the option itself mandatory. Without this guard a
        # missing `--from` reaches the migration lookup as `None`.
        if not from_migration_name:
            self.line_error(
                "The --from option is required to squash migrations.",
                style="error",
            )
            return 1

        migrations = self.get_migrations()
        operation = GenerateSquashedMigration(
            migrations,
            from_migration_name,
            to_migration_name,
        )
        await self.run_operation(operation)
2 changes: 2 additions & 0 deletions nodestream/cli/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .execute_migration import ExecuteMigrations
from .generate_migration import GenerateMigration
from .generate_pipeline_scaffold import GeneratePipelineScaffold
from .generate_squashed_migration import GenerateSquashedMigration
from .initialize_logger import InitializeLogger
from .initialize_project import InitializeProject
from .operation import Operation
Expand All @@ -20,6 +21,7 @@
"ExecuteMigrations",
"GenerateMigration",
"GeneratePipelineScaffold",
"GenerateSquashedMigration",
"InitializeLogger",
"InitializeProject",
"Operation",
Expand Down
33 changes: 33 additions & 0 deletions nodestream/cli/operations/generate_squashed_migration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from ...schema.migrations import ProjectMigrations
from .operation import NodestreamCommand, Operation


class GenerateSquashedMigration(Operation):
    """Operation that squashes a span of migrations and reports the outcome."""

    def __init__(
        self,
        migrations: ProjectMigrations,
        from_migration_name: str,
        to_migration_name: str,
    ) -> None:
        """Remember the project migrations and the names bounding the squash.

        Args:
            migrations: The project migrations to squash within.
            from_migration_name: Name of the migration to squash from.
            to_migration_name: Name of the migration to squash to; a falsy
                value means no upper bound is looked up.
        """
        self.migrations = migrations
        self.from_migration_name = from_migration_name
        self.to_migration_name = to_migration_name

    async def perform(self, command: NodestreamCommand):
        """Create the squashed migration and print a summary through ``command``.

        Args:
            command: The command whose ``line`` method receives the output.
        """
        project = self.migrations
        start = project.graph.get_migration(self.from_migration_name)

        # The upper bound is optional; only resolve it when a name was given.
        end = None
        if self.to_migration_name:
            end = project.graph.get_migration(self.to_migration_name)

        squashed, written_to = project.create_squash_between(start, end)

        emit = command.line
        emit(f"Generated squashed migration {squashed.name}.")
        emit(f"The migration contains {len(squashed.operations)} schema changes.")
        for schema_change in squashed.operations:
            emit(f" - {schema_change.describe()}")
        emit(f"Migration written to {written_to}")
        emit("Run `nodestream migrations run` to apply the migration.")
133 changes: 121 additions & 12 deletions nodestream/schema/migrations/migrations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List
from typing import Dict, Iterable, List, Optional

from ...file_io import LoadsFromYamlFile, SavesToYamlFile
from .operations import Operation
Expand All @@ -16,12 +16,14 @@ class Migration(LoadsFromYamlFile, SavesToYamlFile):
name: str
operations: List[Operation]
dependencies: List[str]
replaces: List[str] = field(default_factory=list)

def to_file_data(self):
    """Serialize this migration into a plain mapping.

    Returns:
        A dict with the keys ``name``, ``operations``, ``dependencies`` and
        ``replaces``; each operation is serialized through its own
        ``to_file_data``.
    """
    serialized_operations = [step.to_file_data() for step in self.operations]
    return dict(
        name=self.name,
        operations=serialized_operations,
        dependencies=self.dependencies,
        replaces=self.replaces,
    )

@classmethod
Expand All @@ -33,17 +35,19 @@ def from_file_data(cls, file_data):
for operation_file_data in file_data["operations"]
],
dependencies=file_data["dependencies"],
replaces=file_data.get("replaces", []),
)

@classmethod
def describe_yaml_schema(cls):
    """Describe the layout of a migration file as a ``schema.Schema``.

    Returns:
        A ``Schema`` requiring ``name``, ``operations`` and ``dependencies``
        and accepting an optional ``replaces`` list of strings.
    """
    # Aliased so the `Optional` imported from `typing` at module level is not
    # shadowed inside this method.
    from schema import Optional as OptionalKey
    from schema import Schema

    layout = {
        "name": str,
        "operations": [Operation.describe_yaml_schema()],
        "dependencies": [str],
        OptionalKey("replaces"): [str],
    }
    return Schema(layout)

Expand Down Expand Up @@ -74,6 +78,53 @@ def write_to_file_with_default_name(self, directory: Path) -> Path:
self.write_to_file(path)
return path

def is_squashed_migration(self) -> bool:
    """Check whether this migration is a squashed migration.

    A migration counts as squashed when its ``replaces`` list names at least
    one migration.

    Returns:
        True if this migration replaces one or more migrations, False
        otherwise.
    """
    return bool(self.replaces)

@classmethod
def squash(
    cls,
    new_name: str,
    migrations: Iterable["Migration"],
    optimize_operations: bool = True,
) -> "Migration":
    """Build a new migration that is the squash of the given migrations.

    Args:
        new_name: The name of the new migration.
        migrations: The migrations to squash, in the order their operations
            should be applied. Any iterable is accepted, including one-shot
            iterators.
        optimize_operations: Whether to pass the combined operations through
            ``Operation.optimize`` before building the migration.

    Returns:
        A migration whose operations are those of the given migrations, whose
        ``replaces`` lists their names in input order, and whose
        ``dependencies`` are the dependencies that point outside the squashed
        set, in first-seen order.
    """
    # Materialize once: the input is walked several times below, and a
    # generator would be exhausted after the first pass.
    migrations = list(migrations)

    # `dict.fromkeys` de-duplicates while keeping first-seen order, so the
    # resulting lists are stable from run to run (a set's order is not).
    replaced_names = list(dict.fromkeys(m.name for m in migrations))
    replaced_lookup = set(replaced_names)

    effective_operations = [o for m in migrations for o in m.operations]
    dependencies = list(
        dict.fromkeys(
            d
            for m in migrations
            for d in m.dependencies
            if d not in replaced_lookup
        )
    )
    if optimize_operations:
        effective_operations = Operation.optimize(effective_operations)

    return cls(
        name=new_name,
        operations=effective_operations,
        replaces=replaced_names,
        dependencies=dependencies,
    )


@dataclass(frozen=True, slots=True)
class MigrationGraph:
Expand All @@ -98,20 +149,56 @@ def get_migration(self, name: str) -> Migration:
"""
return self.migrations_by_name[name]

def get_ordered_migration_plan(self) -> List[Migration]:
def get_ordered_migration_plan(
self, completed_migrations: List[Migration]
) -> List[Migration]:
completed_migration_names = {m.name for m in completed_migrations}
replacement_index = {r: m for m in self.all_migrations() for r in m.replaces}
plan_order = []

for migration in self.all_migrations():
for required_migration in self.plan_to_execute_migration(migration):
if required_migration not in plan_order:
plan_order.append(required_migration)
for migration in self.topological_order():
if migration.name in completed_migration_names:
continue

# If we are considering a migration that has been replaced,
# we _only_ want to add it if at least one of the migrations
# replaced by the replacing migration has already been completed.
# In other words, if the changes from a squashed migration have
# (at least partially) been applied, we don't want to add
# the squashed migration to the plan but will want to add the
# remaining replaced migrations to the plan.
if (replacement := replacement_index.get(migration.name)) is not None:
if not any(
r in completed_migration_names for r in replacement.replaces
):
continue

# Similarly, if we are looking at a squashed migration, we want to
# add it to the plan only if none of the migrations that it replaces
# have been completed.
if migration.is_squashed_migration():
if any(r in completed_migration_names for r in migration.replaces):
continue

# Now here we are in one of three states:
#
# 1. The migration is some "regular" migration that has not been
# completed.
# 2. The migration is a squashed migration that we want to add to
# the plan.
# 3. The migration is a migration that has been replaced by a
# squashed migration, but at least one of the migrations replaced
# by that squashed migration has already been completed.
#
# In all of these cases, we want to add the migration to the plan.
plan_order.append(migration)

return plan_order

def _iterative_dfs_traversal(self, start_node: Migration) -> List[Migration]:
def _iterative_dfs_traversal(self, *start_node: Migration) -> List[Migration]:
visited_order = []
visited_set = set()
stack = [(start_node, False)]
stack = [(n, False) for n in start_node]

while stack:
node, processed = stack.pop()
Expand All @@ -129,8 +216,30 @@ def _iterative_dfs_traversal(self, start_node: Migration) -> List[Migration]:

return visited_order

def plan_to_execute_migration(self, migration: Migration) -> List[Migration]:
return self._iterative_dfs_traversal(migration)
def squash_between(
    self,
    name: str,
    from_migration: Migration,
    to_migration: Optional[Migration] = None,
):
    """Squash all migrations between two migrations.

    The range is taken from the list returned by ``topological_order`` and
    includes both end points.

    Args:
        name: The name of the new squashed migration.
        from_migration: The migration to start squashing from.
        to_migration: The migration to stop squashing at. When omitted, the
            range runs to the end of the ordered list.

    Returns:
        The new squashed migration.

    Raises:
        ValueError: If either migration is not part of the ordered list, or
            if ``to_migration`` comes before ``from_migration`` in it.
    """
    ordered = self.topological_order()
    from_index = ordered.index(from_migration)
    to_index = ordered.index(to_migration) if to_migration else len(ordered) - 1

    # A reversed range would slice to an empty list and silently produce a
    # squashed migration that contains nothing and replaces nothing.
    if to_index < from_index:
        raise ValueError(
            f"Cannot squash from {from_migration.name!r} to "
            f"{to_migration.name!r}: the 'to' migration comes before the "
            "'from' migration."
        )

    migrations_to_squash = ordered[from_index : to_index + 1]
    return Migration.squash(name, migrations_to_squash)

def topological_order(self):
    """Return the migrations visited by a depth-first walk from the leaves.

    Returns:
        The list produced by ``_iterative_dfs_traversal`` when started from
        every migration returned by ``get_leaf_migrations``.
    """
    leaves = self.get_leaf_migrations()
    return self._iterative_dfs_traversal(*leaves)

def all_migrations(self) -> Iterable[Migration]:
"""Iterate over all migrations in the graph."""
Expand Down
Loading

0 comments on commit db2592d

Please sign in to comment.