diff --git a/dbt_meshify/storage/yaml_editors.py b/dbt_meshify/storage/yaml_editors.py index bf3b32f..bedb7f5 100644 --- a/dbt_meshify/storage/yaml_editors.py +++ b/dbt_meshify/storage/yaml_editors.py @@ -98,9 +98,9 @@ def add_model_contract_to_yml( # isolate the columns from the existing model entry yml_cols: List[Dict] = model_yml.get("columns", []) catalog_cols = model_catalog.columns or {} if model_catalog else {} + catalog_cols = {k.lower(): v for k, v in catalog_cols.items()} - # add the data type to the yml entry for columns that are in yml - # import pdb; pdb.set_trace() + # add the data type to the yml entry for columns that are in yml yml_cols = [ {**yml_col, "data_type": catalog_cols[yml_col["name"]].type.lower()} for yml_col in yml_cols diff --git a/tests/fixtures.py b/tests/fixtures.py index 887ca6d..f90d9f4 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -25,7 +25,7 @@ "owner": None, }, "columns": { - "id": {"type": "INTEGER", "index": 1, "name": "id", "comment": None}, + "ID": {"type": "INTEGER", "index": 1, "name": "id", "comment": None}, "colleague": {"type": "VARCHAR", "index": 2, "name": "colleague", "comment": None}, }, "stats": { @@ -63,6 +63,17 @@ description: "this is the id column" """ +model_yml_one_col_one_test = """ +models: + - name: shared_model + description: "this is a test model" + columns: + - name: id + description: "this is the id column" + tests: + - unique +""" + model_yml_all_col = """ models: - name: shared_model @@ -102,6 +113,23 @@ data_type: varchar """ +expected_contract_yml_one_col_one_test = """ +models: + - name: shared_model + config: + contract: + enforced: true + description: "this is a test model" + columns: + - name: id + description: "this is the id column" + data_type: integer + tests: + - unique + - name: colleague + data_type: varchar +""" + expected_contract_yml_all_col = """ models: - name: shared_model diff --git a/tests/unit/test_add_contract_to_yml.py b/tests/unit/test_add_contract_to_yml.py index 0a9efe5..ca2d13e 100644 --- a/tests/unit/test_add_contract_to_yml.py +++ b/tests/unit/test_add_contract_to_yml.py @@ -6,10 +6,12 @@ expected_contract_yml_no_entry, expected_contract_yml_one_col, expected_contract_yml_other_model, + expected_contract_yml_one_col_one_test, model_yml_all_col, model_yml_no_col_no_version, model_yml_one_col, model_yml_other_model, + model_yml_one_col_one_test, shared_model_catalog_entry, ) from . import read_yml @@ -36,6 +38,14 @@ def test_add_contract_to_yml_one_col(self): ) assert yml_dict == read_yml(expected_contract_yml_one_col) + def test_add_contract_to_yml_one_col_one_test(self): + yml_dict = meshify.add_model_contract_to_yml( + models_yml=read_yml(model_yml_one_col_one_test), + model_catalog=catalog_entry, + model_name=model_name, + ) + assert yml_dict == read_yml(expected_contract_yml_one_col_one_test) + def test_add_contract_to_yml_all_col(self): yml_dict = meshify.add_model_contract_to_yml( models_yml=read_yml(model_yml_all_col),