diff --git a/dbt_meshify/storage/dbt_project_editors.py b/dbt_meshify/storage/dbt_project_editors.py index 2b45c91..d0b444d 100644 --- a/dbt_meshify/storage/dbt_project_editors.py +++ b/dbt_meshify/storage/dbt_project_editors.py @@ -4,6 +4,7 @@ from dbt.contracts.graph.nodes import ( CompiledNode, GenericTestNode, + Macro, ModelNode, NodeType, Resource, @@ -200,6 +201,11 @@ def initialize(self) -> ChangeSet: ) ) + if isinstance(resource, Macro) and resource.macro_sql: + changes = reference_updater.update_macro_refs(resource) + if changes: + change_set.add(changes) + else: change_set.add(self.copy_resource(resource)) diff --git a/dbt_meshify/utilities/references.py b/dbt_meshify/utilities/references.py index d6ccadb..18f69e8 100644 --- a/dbt_meshify/utilities/references.py +++ b/dbt_meshify/utilities/references.py @@ -2,7 +2,13 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from dbt.contracts.graph.nodes import CompiledNode, Exposure, Resource, SemanticModel +from dbt.contracts.graph.nodes import ( + CompiledNode, + Exposure, + Macro, + Resource, + SemanticModel, +) from loguru import logger from dbt_meshify.change import ( @@ -167,16 +173,16 @@ def update_yml_resource_references( def generate_reference_update( self, project_name: str, - upstream_node: CompiledNode, - downstream_node: Union[Resource, CompiledNode], + upstream_node: Union[CompiledNode, Macro], + downstream_node: Union[Resource, CompiledNode, Macro], downstream_project: PathedProject, code: str, ) -> Union[FileChange, ResourceChange]: """Generate FileChanges that update the references in the downstream_node's code.""" - change: Union[FileChange, ResourceChange] - if isinstance(downstream_node, CompiledNode): - updated_code = self.ref_update_methods[downstream_node.language]( + if isinstance(downstream_node, CompiledNode) or isinstance(downstream_node, Macro): + language = downstream_node.language if hasattr(downstream_node, "language") else "sql" + updated_code = 
def update_macro_refs(
    self,
    resource: Macro,
) -> Optional[FileChange]:
    """Generate a FileChange that rewrites the `ref` calls inside a macro.

    Every single-argument ``{{ ref('model') }}`` call in the macro body is inspected.
    References to parent-project models that are not part of this project are rewritten;
    references to models that already live in this project are left untouched.

    Args:
        resource: The macro whose ``macro_sql`` is scanned for ``ref`` calls.

    Returns:
        The FileChange produced by ``generate_reference_update``, or ``None`` when
        ``self.project`` is a ``DbtProject``, the macro has no SQL, no ``ref`` call is
        found, or every referenced model is already part of this project.

    Raises:
        ValueError: If a referenced model is not among the parent project's resources.
        TypeError: If ``generate_reference_update`` returns something other than a
            FileChange.
    """

    # If we're accidentally operating in a root project, do not rewrite anything.
    if isinstance(self.project, DbtProject):
        return None

    # Nothing to scan; also keeps the regex from being handed `None`.
    if not resource.macro_sql:
        return None

    parent_project = self.project.parent_project

    # Model names may contain digits (e.g. `stg_orders_2023`), so accept them as well.
    pattern = re.compile(
        r"{{\s*ref\s*\(\s*['\"]([a-zA-Z0-9_]+)['\"](?:,\s*(v=\d+))?\s*\)\s*}}"
    )

    # Collect every distinct upstream model that stays behind in the parent project.
    # A macro may contain several refs, so do not stop at the first match.
    external_model_ids: List[str] = []
    for match in pattern.finditer(resource.macro_sql):
        model_id = f"model.{parent_project.name}.{match.group(1)}"

        if model_id in external_model_ids:
            continue

        # Check if the model is part of the original project.
        if model_id not in parent_project.resources:
            raise ValueError(
                f"Unable to find {model_id} in the parent project. How did we get here?"
            )

        # Check if the model is part of the new project. If it is, skip it.
        # We don't want to rewrite the ref for a model already in this project.
        if model_id in self.project.resources:
            logger.debug(f"Model {model_id} exists in the new project! Skipping.")
            continue

        external_model_ids.append(model_id)

    if not external_model_ids:
        return None

    # `generate_reference_update` rewrites refs to a single upstream node. Refs to any
    # other external model are rewritten first with the same SQL rewriter it uses, and
    # the result is handed over as `code` so that one FileChange carries every rewrite.
    # NOTE(review): assumes the rewriter takes exactly the keyword arguments that
    # `generate_reference_update` passes to it -- confirm.
    *earlier_ids, final_id = external_model_ids
    code = resource.macro_sql
    for model_id in earlier_ids:
        code = self.ref_update_methods["sql"](
            model_name=parent_project.models[model_id].name,
            project_name=parent_project.name,
            model_code=code,
        )

    change = self.generate_reference_update(
        project_name=parent_project.name,
        upstream_node=parent_project.models[final_id],
        downstream_node=resource,
        code=code,
        downstream_project=self.project,
    )

    if not isinstance(change, FileChange):
        raise TypeError(
            "Incorrect change type returned during macro ref updating. "
            "Expected FileChange, but found ResourceChange"
        )

    return change