diff --git a/dbt_meshify/change.py b/dbt_meshify/change.py index 1871805..7c666cc 100644 --- a/dbt_meshify/change.py +++ b/dbt_meshify/change.py @@ -2,7 +2,7 @@ import os from enum import Enum from pathlib import Path -from typing import Dict, Iterable, List, Optional, Protocol +from typing import Callable, Dict, Iterable, List, Optional, Protocol class Operation(str, Enum): @@ -46,6 +46,7 @@ class EntityType(str, Enum): SemanticModel = "semantic_model" Project = "project" Code = "code" + Directory = "directory" def pluralize(self) -> str: if self is self.Analysis: @@ -97,6 +98,14 @@ def __str__(self): ) +@dataclasses.dataclass +class DirectoryChange(BaseChange): + """A DirectoryChange represents a unit of work that should be performed on a Directory in a dbt project.""" + + source: Optional[Path] = None + ignore_function: Optional[Callable] = None + + @dataclasses.dataclass class FileChange(BaseChange): """A FileChange represents a unit of work that should be performed on a File in a dbt project.""" diff --git a/dbt_meshify/change_set_processor.py b/dbt_meshify/change_set_processor.py index 3c5c7d3..30d4de8 100644 --- a/dbt_meshify/change_set_processor.py +++ b/dbt_meshify/change_set_processor.py @@ -4,7 +4,14 @@ from loguru import logger from dbt_meshify.change import Change, ChangeSet, EntityType -from dbt_meshify.storage.file_content_editors import RawFileEditor, ResourceFileEditor +from dbt_meshify.storage.file_content_editors import ( + DirectoryEditor, + RawFileEditor, + ResourceFileEditor, +) + +# Enumeration of valid File Editors +FILE_EDITORS = {EntityType.Code: RawFileEditor, EntityType.Directory: DirectoryEditor} class ChangeSetProcessorException(BaseException): @@ -24,9 +31,7 @@ def __init__(self, dry_run: bool = False) -> None: def write(self, change: Change) -> None: """Commit a Change to the file system.""" - file_editor = ( - RawFileEditor() if change.entity_type == EntityType.Code else ResourceFileEditor() - ) + file_editor = FILE_EDITORS.get(change.entity_type, ResourceFileEditor)() file_editor.__getattribute__(change.operation)(change) diff --git a/dbt_meshify/storage/dbt_project_editors.py b/dbt_meshify/storage/dbt_project_editors.py index d0b444d..187db48 100644 --- a/dbt_meshify/storage/dbt_project_editors.py +++ b/dbt_meshify/storage/dbt_project_editors.py @@ -1,4 +1,6 @@ +import shutil from pathlib import Path +from tracemalloc import start from typing import Dict, Optional, Set from dbt.contracts.graph.nodes import ( @@ -12,9 +14,11 @@ ) from dbt.node_types import AccessType from loguru import logger +from regex import D from dbt_meshify.change import ( ChangeSet, + DirectoryChange, EntityType, FileChange, Operation, @@ -30,6 +34,20 @@ from dbt_meshify.utilities.references import ReferenceUpdater +def get_starter_project_path() -> Path: + """Obtain the path of a dbt starter project on the local filesystem.""" + + from importlib import resources + + import dbt.include.starter_project + + starter_path = Path(str(resources.files(dbt.include.starter_project))) + assert starter_path is not None + assert (starter_path / "dbt_project.yml").exists() + + return starter_path + + class DbtSubprojectCreator: """ Takes a `DbtSubProject` and creates the directory structure and files for it. @@ -85,14 +103,29 @@ def _get_subproject_boundary_models(self) -> Set[str]: ) return boundary_models + def create_starter_project(self) -> DirectoryChange: + """Create a new starter project using the default stored in dbt-core""" + + return DirectoryChange( + operation=Operation.Copy, + entity_type=EntityType.Directory, + identifier=str(self.subproject.path), + path=self.subproject.path, + source=get_starter_project_path(), + ignore_function=shutil.ignore_patterns("__init__.py", "__pycache__", "**/*.pyc"), + ) + def write_project_file(self) -> FileChange: """ Writes the dbt_project.yml file for the subproject in the specified subdirectory """ + + # Read a starter `dbt_project.yml` file as a baseline + starter_path: Path = get_starter_project_path() / "dbt_project.yml" + starter_dbt_project = YAMLFileManager.read_file(starter_path) + contents = self.subproject.project.to_dict() - # was getting a weird serialization error from ruamel on this value - # it's been deprecated, so no reason to keep it - contents.pop("version") + # this one appears in the project yml, but i don't think it should be written contents.pop("query-comment") contents = filter_empty_dict_items(contents) @@ -105,12 +138,22 @@ def write_project_file(self) -> FileChange: if max([len(version) for version in contents["require-dbt-version"]]) == 1: contents["require-dbt-version"] = "".join(contents["require-dbt-version"]) + for key, value in contents.items(): + if value is None: + continue + + if isinstance(value, (list, dict, tuple)): + if len(value) == 0: + continue + + starter_dbt_project[key] = value + return FileChange( operation=Operation.Add, entity_type=EntityType.Code, identifier="dbt_project.yml", path=self.subproject.path / Path("dbt_project.yml"), - data=yaml.dump(contents), + data=yaml.dump(starter_dbt_project), ) def copy_packages_yml_file(self) -> FileChange: @@ -143,6 +186,8 @@ def initialize(self) -> ChangeSet: f"Identifying operations required to split {subproject.name} from {subproject.parent_project.name}." ) + change_set.add(self.create_starter_project()) + for unique_id in ( subproject.resources | subproject.custom_macros diff --git a/dbt_meshify/storage/file_content_editors.py b/dbt_meshify/storage/file_content_editors.py index 5f88c44..e2eb690 100644 --- a/dbt_meshify/storage/file_content_editors.py +++ b/dbt_meshify/storage/file_content_editors.py @@ -2,9 +2,13 @@ from loguru import logger -from dbt_meshify.change import EntityType, FileChange, ResourceChange +from dbt_meshify.change import DirectoryChange, EntityType, FileChange, ResourceChange from dbt_meshify.exceptions import FileEditorException -from dbt_meshify.storage.file_manager import RawFileManager, YAMLFileManager +from dbt_meshify.storage.file_manager import ( + DirectoryManager, + RawFileManager, + YAMLFileManager, +) class NamedList(dict): @@ -89,6 +93,18 @@ def safe_update(original: Dict[Any, Any], update: Dict[Any, Any]) -> Dict[Any, A return original +class DirectoryEditor: + """A helper class used to perform filesystem operations on Directories""" + + @staticmethod + def copy(change: DirectoryChange): + """Copy a file from one location to another.""" + if change.source is None: + raise FileEditorException("None source value provided in Copy operation.") + + DirectoryManager.copy_directory(change.source, change.path, change.ignore_function) + + class RawFileEditor: """A class used to perform Raw operations on Files""" diff --git a/dbt_meshify/storage/file_manager.py b/dbt_meshify/storage/file_manager.py index e845dcd..0a8ad82 100644 --- a/dbt_meshify/storage/file_manager.py +++ b/dbt_meshify/storage/file_manager.py @@ -3,7 +3,7 @@ import shutil from pathlib import Path -from typing import Any, Dict, List, Protocol +from typing import Any, Callable, Dict, List, Optional, Protocol from dbt.contracts.util import Identifier from ruamel.yaml import YAML @@ -18,6 +18,7 @@ def __init__(self): self.preserve_quotes = True self.width = 4096 self.indent(mapping=2, sequence=4, offset=2) + self.default_flow_style = False def dump(self, data, stream=None, **kw): inefficient = False @@ -43,6 +44,21 @@ def write_file(self, path: Path, content: Any) -> None: pass +class DirectoryManager: + """DirectoryManager is a FileManager for operating on directories in the filesystem""" + + @staticmethod + def copy_directory( + source_path: Path, target_path: Path, ignore_function: Optional[Callable] = None + ) -> None: + """Copy a directory from source to target""" + + if not target_path.parent.exists(): + target_path.parent.mkdir(parents=True, exist_ok=True) + + shutil.copytree(source_path, target_path, symlinks=True, ignore=ignore_function) + + class RawFileManager: """RawFileManager is a FileManager for operating on raw files in the filesystem.""" diff --git a/tests/integration/test_split_command.py b/tests/integration/test_split_command.py index 90d5abf..084d80d 100644 --- a/tests/integration/test_split_command.py +++ b/tests/integration/test_split_command.py @@ -9,6 +9,7 @@ src_project_path = "test-projects/split/split_proj" dest_project_path = "test-projects/split/temp_proj" +project_path: Path = Path(dest_project_path) / "my_new_project" @pytest.fixture @@ -34,27 +35,27 @@ def test_split_one_model(self, project): ) assert result.exit_code == 0 - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "customers.sql" - ).exists() + assert (project_path / "models" / "marts" / "customers.sql").exists() x_proj_ref = "{{ ref('split_proj', 'stg_customers') }}" - child_sql = ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "customers.sql" - ).read_text() + child_sql = (project_path / "models" / "marts" / "customers.sql").read_text() assert x_proj_ref in child_sql # Copied a referenced docs block - assert (Path(dest_project_path) / "my_new_project" / "models" / "docs.md").exists() - assert ( - "customer_id" - in (Path(dest_project_path) / "my_new_project" / "models" / "docs.md").read_text() - ) + assert (project_path / "models" / "docs.md").exists() + assert "customer_id" in (project_path / "models" / "docs.md").read_text() # Updated references in a moved macro - redirect_path = Path(dest_project_path) / "my_new_project" / "macros" / "redirect.sql" + redirect_path = project_path / "macros" / "redirect.sql" assert (redirect_path).exists() assert "ref('split_proj', 'orders')" in redirect_path.read_text() + # Starter project assets exist. + assert (project_path / "analyses").exists() + assert (project_path / "macros").exists() + assert (project_path / "seeds").exists() + assert (project_path / "snapshots").exists() + assert (project_path / "tests").exists() + def test_split_one_model_one_source(self, project): runner = CliRunner() result = runner.invoke( @@ -70,19 +71,13 @@ def test_split_one_model_one_source(self, project): ) assert result.exit_code == 0 - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "stg_orders.sql" - ).exists() - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "__sources.yml" - ).exists() + assert (project_path / "models" / "staging" / "stg_orders.sql").exists() + assert (project_path / "models" / "staging" / "__sources.yml").exists() source_yml = yaml.safe_load( (Path(dest_project_path) / "models" / "staging" / "__sources.yml").read_text() ) new_source_yml = yaml.safe_load( - ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "__sources.yml" - ).read_text() + (project_path / "models" / "staging" / "__sources.yml").read_text() ) assert len(source_yml["sources"][0]["tables"]) == 5 @@ -116,35 +111,25 @@ def test_split_one_model_one_source_custom_macro(self, project): / "stg_order_items.sql" ).exists() # moved source - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "__sources.yml" - ).exists() + assert (project_path / "models" / "staging" / "__sources.yml").exists() # copied custom macro - assert ( - Path(dest_project_path) / "my_new_project" / "macros" / "cents_to_dollars.sql" - ).exists() + assert (project_path / "macros" / "cents_to_dollars.sql").exists() # Confirm that we did not bring over an unreferenced dollars_to_cents macro assert ( "dollars_to_cents" - not in ( - Path(dest_project_path) / "my_new_project" / "macros" / "cents_to_dollars.sql" - ).read_text() + not in (project_path / "macros" / "cents_to_dollars.sql").read_text() ) # copied custom macro parents too! - assert ( - Path(dest_project_path) / "my_new_project" / "macros" / "type_numeric.sql" - ).exists() + assert (project_path / "macros" / "type_numeric.sql").exists() # assert only one source moved source_yml = yaml.safe_load( (Path(dest_project_path) / "models" / "staging" / "__sources.yml").read_text() ) new_source_yml = yaml.safe_load( - ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "__sources.yml" - ).read_text() + (project_path / "models" / "staging" / "__sources.yml").read_text() ) assert len(source_yml["sources"][0]["tables"]) == 5 assert len(new_source_yml["sources"][0]["tables"]) == 1 @@ -200,17 +185,13 @@ def test_split_child_project(self): assert result.exit_code == 0 # selected model is moved to subdirectory - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "stg_orders.sql" - ).exists() + assert (project_path / "models" / "staging" / "stg_orders.sql").exists() x_proj_ref = "{{ ref('split_proj', 'stg_order_items') }}" - child_sql = ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "orders.sql" - ).read_text() + child_sql = (project_path / "models" / "marts" / "orders.sql").read_text() # split model is updated to reference the parent project assert x_proj_ref in child_sql # dependencies.yml is moved to subdirectory, not the parent project - assert (Path(dest_project_path) / "my_new_project" / "dependencies.yml").exists() + assert (project_path / "dependencies.yml").exists() assert not (Path(dest_project_path) / "dependencies.yml").exists() teardown_test_project(dest_project_path) @@ -255,16 +236,9 @@ def test_split_public_leaf_nodes(self, project): ) assert result.exit_code == 0 - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "leaf_node.sql" - ).exists() + assert (project_path / "models" / "marts" / "leaf_node.sql").exists() # this is lazy, but this should be the only model in that yml file - assert ( - "access: public" - in ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "__models.yml" - ).read_text() - ) + assert "access: public" in (project_path / "models" / "marts" / "__models.yml").read_text() def test_split_upstream_multiple_boundary_parents(self): """ @@ -287,9 +261,7 @@ def test_split_upstream_multiple_boundary_parents(self): assert result.exit_code == 0 # selected model is moved to subdirectory - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "staging" / "stg_orders.sql" - ).exists() + assert (project_path / "models" / "staging" / "stg_orders.sql").exists() x_proj_ref_1 = "{{ ref('my_new_project', 'stg_orders') }}" x_proj_ref_2 = "{{ ref('my_new_project', 'stg_order_items') }}" child_sql = (Path(dest_project_path) / "models" / "marts" / "orders.sql").read_text() @@ -319,17 +291,13 @@ def test_split_downstream_multiple_boundary_parents(self): assert result.exit_code == 0 # selected model is moved to subdirectory - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "orders.sql" - ).exists() + assert (project_path / "models" / "marts" / "orders.sql").exists() x_proj_ref_1 = "{{ ref('split_proj', 'stg_orders') }}" x_proj_ref_2 = "{{ ref('split_proj', 'stg_order_items') }}" x_proj_ref_3 = "{{ ref('split_proj', 'stg_products') }}" x_proj_ref_4 = "{{ ref('split_proj', 'stg_locations') }}" x_proj_ref_5 = "{{ ref('split_proj', 'stg_supplies') }}" - child_sql = ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "orders.sql" - ).read_text() + child_sql = (project_path / "models" / "marts" / "orders.sql").read_text() # downstream model has all cross project refs updated assert x_proj_ref_1 in child_sql assert x_proj_ref_2 in child_sql @@ -359,9 +327,7 @@ def test_split_updates_exposure_depends_on(self): assert result.exit_code == 0 # selected model is moved to subdirectory - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "orders.sql" - ).exists() + assert (project_path / "models" / "marts" / "orders.sql").exists() old_ref = "ref('orders')" x_proj_ref = "ref('my_new_project', 'orders')" exposure_yml = yaml.safe_load( @@ -397,9 +363,7 @@ def test_split_updates_semantic_model_model(self): assert result.exit_code == 0 # selected model is moved to subdirectory - assert ( - Path(dest_project_path) / "my_new_project" / "models" / "marts" / "orders.sql" - ).exists() + assert (project_path / "models" / "marts" / "orders.sql").exists() old_ref = "ref('orders')" x_proj_ref = "ref('my_new_project', 'orders')" sm_yml = yaml.safe_load(