diff --git a/.changes/unreleased/Fixes-20230612-175854.yaml b/.changes/unreleased/Fixes-20230612-175854.yaml new file mode 100644 index 00000000000..2353d4d8152 --- /dev/null +++ b/.changes/unreleased/Fixes-20230612-175854.yaml @@ -0,0 +1,7 @@ +kind: Fixes +body: Update SemanticModel node to properly implement the DSI 0.1.0dev3 SemanticModel + protocol spec +time: 2023-06-12T17:58:54.289704-07:00 +custom: + Author: QMalcolm + Issue: 7833 7827 diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 5f3513fbda3..a3a00441a41 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -12,17 +12,21 @@ from dbt.clients.system import write_file from dbt.contracts.files import FileHash -from dbt.contracts.graph.unparsed import ( +from dbt.contracts.graph.semantic_models import ( + Defaults, Dimension, - Docs, Entity, + Measure, + SourceFileMetadata, +) +from dbt.contracts.graph.unparsed import ( + Docs, ExposureType, ExternalTable, FreshnessThreshold, HasYamlMetadata, MacroArgument, MaturityType, - Measure, Owner, Quoting, TestDef, @@ -43,7 +47,11 @@ from dbt.events.contextvars import set_contextvars from dbt.flags import get_flags from dbt.node_types import ModelLanguage, NodeType, AccessType -from dbt_semantic_interfaces.references import MeasureReference +from dbt_semantic_interfaces.references import ( + MeasureReference, + LinkableElementReference, + SemanticModelReference, +) from dbt_semantic_interfaces.references import MetricReference as DSIMetricReference from dbt_semantic_interfaces.type_enums.metric_type import MetricType from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity @@ -554,30 +562,6 @@ def depends_on_macros(self): return self.depends_on.macros -@dataclass -class FileSlice(dbtClassMixin, Replaceable): - """Provides file slice level context about what something was created from.
- - Implementation of the dbt-semantic-interfaces `FileSlice` protocol - """ - - filename: str - content: str - start_line_number: int - end_line_number: int - - -@dataclass -class SourceFileMetadata(dbtClassMixin, Replaceable): - """Provides file context about what something was created from. - - Implementation of the dbt-semantic-interfaces `Metadata` protocol - """ - - repo_file_path: str - file_slice: FileSlice - - # ==================================== # CompiledNode subclasses # ==================================== @@ -703,7 +687,6 @@ def same_contract(self, old, adapter_type=None) -> bool: and old_value.constraints != self.columns[old_key].constraints and old.materialization_enforces_constraints ): - for old_constraint in old_value.constraints: if ( old_constraint not in self.columns[old_key].constraints @@ -1493,12 +1476,63 @@ class NodeRelation(dbtClassMixin): @dataclass class SemanticModel(GraphNode): - description: Optional[str] model: str node_relation: Optional[NodeRelation] - entities: Sequence[Entity] - measures: Sequence[Measure] - dimensions: Sequence[Dimension] + description: Optional[str] = None + defaults: Optional[Defaults] = None + entities: Sequence[Entity] = field(default_factory=list) + measures: Sequence[Measure] = field(default_factory=list) + dimensions: Sequence[Dimension] = field(default_factory=list) + metadata: Optional[SourceFileMetadata] = None + + @property + def entity_references(self) -> List[LinkableElementReference]: + return [entity.reference for entity in self.entities] + + @property + def dimension_references(self) -> List[LinkableElementReference]: + return [dimension.reference for dimension in self.dimensions] + + @property + def measure_references(self) -> List[MeasureReference]: + return [measure.reference for measure in self.measures] + + @property + def has_validity_dimensions(self) -> bool: + return any([dim.validity_params is not None for dim in self.dimensions]) + + @property + def validity_start_dimension(self) -> 
Optional[Dimension]: + validity_start_dims = [ + dim for dim in self.dimensions if dim.validity_params and dim.validity_params.is_start + ] + if not validity_start_dims: + return None + return validity_start_dims[0] + + @property + def validity_end_dimension(self) -> Optional[Dimension]: + validity_end_dims = [ + dim for dim in self.dimensions if dim.validity_params and dim.validity_params.is_end + ] + if not validity_end_dims: + return None + return validity_end_dims[0] + + @property + def partitions(self) -> List[Dimension]: # noqa: D + return [dim for dim in self.dimensions or [] if dim.is_partition] + + @property + def partition(self) -> Optional[Dimension]: + partitions = self.partitions + if not partitions: + return None + return partitions[0] + + @property + def reference(self) -> SemanticModelReference: + return SemanticModelReference(semantic_model_name=self.name) # ==================================== diff --git a/core/dbt/contracts/graph/semantic_models.py b/core/dbt/contracts/graph/semantic_models.py new file mode 100644 index 00000000000..596b8075d49 --- /dev/null +++ b/core/dbt/contracts/graph/semantic_models.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass +from dbt.dataclass_schema import dbtClassMixin +from dbt_semantic_interfaces.references import ( + DimensionReference, + EntityReference, + MeasureReference, + TimeDimensionReference, +) +from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType +from dbt_semantic_interfaces.type_enums.dimension_type import DimensionType +from dbt_semantic_interfaces.type_enums.entity_type import EntityType +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing import List, Optional + + +@dataclass +class FileSlice(dbtClassMixin): + """Provides file slice level context about what something was created from. 
+ + Implementation of the dbt-semantic-interfaces `FileSlice` protocol + """ + + filename: str + content: str + start_line_number: int + end_line_number: int + + +@dataclass +class SourceFileMetadata(dbtClassMixin): + """Provides file context about what something was created from. + + Implementation of the dbt-semantic-interfaces `Metadata` protocol + """ + + repo_file_path: str + file_slice: FileSlice + + +@dataclass +class Defaults(dbtClassMixin): + agg_time_dimension: Optional[str] = None + + +# ==================================== +# Dimension objects +# ==================================== + + +@dataclass +class DimensionValidityParams(dbtClassMixin): + is_start: bool = False + is_end: bool = False + + +@dataclass +class DimensionTypeParams(dbtClassMixin): + time_granularity: TimeGranularity + validity_params: Optional[DimensionValidityParams] = None + + +@dataclass +class Dimension(dbtClassMixin): + name: str + type: DimensionType + description: Optional[str] = None + is_partition: bool = False + type_params: Optional[DimensionTypeParams] = None + expr: Optional[str] = None + metadata: Optional[SourceFileMetadata] = None + + @property + def reference(self) -> DimensionReference: + return DimensionReference(element_name=self.name) + + @property + def time_dimension_reference(self) -> Optional[TimeDimensionReference]: + if self.type == DimensionType.TIME: + return TimeDimensionReference(element_name=self.name) + else: + return None + + @property + def validity_params(self) -> Optional[DimensionValidityParams]: + if self.type_params: + return self.type_params.validity_params + else: + return None + + +# ==================================== +# Entity objects +# ==================================== + + +@dataclass +class Entity(dbtClassMixin): + name: str + type: EntityType + description: Optional[str] = None + role: Optional[str] = None + expr: Optional[str] = None + + @property + def reference(self) -> EntityReference: + return 
EntityReference(element_name=self.name) + + @property + def is_linkable_entity_type(self) -> bool: + return self.type in (EntityType.PRIMARY, EntityType.UNIQUE, EntityType.NATURAL) + + +# ==================================== +# Measure objects +# ==================================== + + +@dataclass +class MeasureAggregationParameters(dbtClassMixin): + percentile: Optional[float] = None + use_discrete_percentile: Optional[bool] = None + use_approximate_percentile: Optional[bool] = None + + +@dataclass +class NonAdditiveDimension(dbtClassMixin): + name: str + window_choice: AggregationType + window_grouples: List[str] + + +@dataclass +class Measure(dbtClassMixin): + name: str + agg: AggregationType + description: Optional[str] = None + create_metric: bool = False + expr: Optional[str] = None + agg_params: Optional[MeasureAggregationParameters] = None + non_additive_dimension: Optional[NonAdditiveDimension] = None + agg_time_dimension: Optional[str] = None + + @property + def checked_agg_time_dimension(self) -> TimeDimensionReference: + if self.agg_time_dimension is not None: + return TimeDimensionReference(element_name=self.agg_time_dimension) + else: + raise Exception("Measure is missing agg_time_dimension!") + + @property + def reference(self) -> MeasureReference: + return MeasureReference(element_name=self.name) diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index eaa9c75e46a..593fd4c5fc1 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -3,6 +3,11 @@ from dbt import deprecations from dbt.node_types import NodeType +from dbt.contracts.graph.semantic_models import ( + Defaults, + DimensionValidityParams, + MeasureAggregationParameters, +) from dbt.contracts.util import ( AdditionalPropertiesMixin, Mergeable, @@ -673,52 +678,58 @@ def validate(cls, data): @dataclass -class Entity(dbtClassMixin): +class UnparsedEntity(dbtClassMixin): name: str - type: str # actually an enum + 
type: str # EntityType enum description: Optional[str] = None role: Optional[str] = None expr: Optional[str] = None @dataclass -class MeasureAggregationParameters(dbtClassMixin): - percentile: Optional[float] = None - use_discrete_percentile: bool = False - use_approximate_percentile: bool = False +class UnparsedNonAdditiveDimension(dbtClassMixin): + name: str + window_choice: str # AggregationType enum + window_grouples: List[str] @dataclass -class Measure(dbtClassMixin): +class UnparsedMeasure(dbtClassMixin): name: str agg: str # actually an enum description: Optional[str] = None - create_metric: Optional[bool] = None + create_metric: bool = False expr: Optional[str] = None agg_params: Optional[MeasureAggregationParameters] = None - non_additive_dimension: Optional[Dict[str, Any]] = None + non_additive_dimension: Optional[UnparsedNonAdditiveDimension] = None agg_time_dimension: Optional[str] = None @dataclass -class Dimension(dbtClassMixin): +class UnparsedDimensionTypeParams(dbtClassMixin): + time_granularity: str # TimeGranularity enum + validity_params: Optional[DimensionValidityParams] = None + + +@dataclass +class UnparsedDimension(dbtClassMixin): name: str type: str # actually an enum description: Optional[str] = None - is_partition: Optional[bool] = False - type_params: Optional[Dict[str, Any]] = None + is_partition: bool = False + type_params: Optional[UnparsedDimensionTypeParams] = None expr: Optional[str] = None - # TODO metadata: Optional[Metadata] (this would actually be the YML for the dimension) @dataclass class UnparsedSemanticModel(dbtClassMixin): name: str - description: Optional[str] model: str # looks like "ref(...)" - entities: List[Entity] = field(default_factory=list) - measures: List[Measure] = field(default_factory=list) - dimensions: List[Dimension] = field(default_factory=list) + description: Optional[str] = None + defaults: Optional[Defaults] = None + entities: List[UnparsedEntity] = field(default_factory=list) + measures: 
List[UnparsedMeasure] = field(default_factory=list) + dimensions: List[UnparsedDimension] = field(default_factory=list) def normalize_date(d: Optional[datetime.date]) -> Optional[datetime.datetime]: diff --git a/core/dbt/parser/schema_yaml_readers.py b/core/dbt/parser/schema_yaml_readers.py index 4a343fc9256..2815f83b450 100644 --- a/core/dbt/parser/schema_yaml_readers.py +++ b/core/dbt/parser/schema_yaml_readers.py @@ -2,12 +2,17 @@ from dbt.parser.common import YamlBlock from dbt.node_types import NodeType from dbt.contracts.graph.unparsed import ( + UnparsedDimension, + UnparsedDimensionTypeParams, + UnparsedEntity, UnparsedExposure, UnparsedGroup, + UnparsedMeasure, UnparsedMetric, UnparsedMetricInput, UnparsedMetricInputMeasure, UnparsedMetricTypeParams, + UnparsedNonAdditiveDimension, UnparsedSemanticModel, ) from dbt.contracts.graph.nodes import ( @@ -21,6 +26,13 @@ SemanticModel, WhereFilter, ) +from dbt.contracts.graph.semantic_models import ( + Dimension, + DimensionTypeParams, + Entity, + Measure, + NonAdditiveDimension, +) from dbt.exceptions import DbtInternalError, YamlParseDictError, JSONValidationError from dbt.context.providers import generate_parse_exposure from dbt.contracts.graph.model_config import MetricConfig, ExposureConfig @@ -31,6 +43,9 @@ ) from dbt.clients.jinja import get_rendered from dbt.dataclass_schema import ValidationError +from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType +from dbt_semantic_interfaces.type_enums.dimension_type import DimensionType +from dbt_semantic_interfaces.type_enums.entity_type import EntityType from dbt_semantic_interfaces.type_enums.metric_type import MetricType from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from typing import List, Optional, Union @@ -408,6 +423,79 @@ def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock): self.schema_parser = schema_parser self.yaml = yaml + def _get_dimension_type_params( + self, unparsed: 
Optional[UnparsedDimensionTypeParams] + ) -> Optional[DimensionTypeParams]: + if unparsed is not None: + return DimensionTypeParams( + time_granularity=TimeGranularity(unparsed.time_granularity), + validity_params=unparsed.validity_params, + ) + else: + return None + + def _get_dimensions(self, unparsed_dimensions: List[UnparsedDimension]) -> List[Dimension]: + dimensions: List[Dimension] = [] + for unparsed in unparsed_dimensions: + dimensions.append( + Dimension( + name=unparsed.name, + type=DimensionType(unparsed.type), + description=unparsed.description, + is_partition=unparsed.is_partition, + type_params=self._get_dimension_type_params(unparsed=unparsed.type_params), + expr=unparsed.expr, + metadata=None, # TODO: requires a fair bit of parsing context + ) + ) + return dimensions + + def _get_entities(self, unparsed_entities: List[UnparsedEntity]) -> List[Entity]: + entities: List[Entity] = [] + for unparsed in unparsed_entities: + entities.append( + Entity( + name=unparsed.name, + type=EntityType(unparsed.type), + description=unparsed.description, + role=unparsed.role, + expr=unparsed.expr, + ) + ) + + return entities + + def _get_non_additive_dimension( + self, unparsed: Optional[UnparsedNonAdditiveDimension] + ) -> Optional[NonAdditiveDimension]: + if unparsed is not None: + return NonAdditiveDimension( + name=unparsed.name, + window_choice=AggregationType(unparsed.window_choice), + window_grouples=unparsed.window_grouples, + ) + else: + return None + + def _get_measures(self, unparsed_measures: List[UnparsedMeasure]) -> List[Measure]: + measures: List[Measure] = [] + for unparsed in unparsed_measures: + measures.append( + Measure( + name=unparsed.name, + agg=AggregationType(unparsed.agg), + description=unparsed.description, + create_metric=unparsed.create_metric, + expr=unparsed.expr, + agg_params=unparsed.agg_params, + non_additive_dimension=self._get_non_additive_dimension( + unparsed.non_additive_dimension + ), + 
agg_time_dimension=unparsed.agg_time_dimension, + ) + ) + return measures + def parse_semantic_model(self, unparsed: UnparsedSemanticModel): package_name = self.project.project_name unique_id = f"{NodeType.SemanticModel}.{package_name}.{unparsed.name}" @@ -427,9 +515,10 @@ def parse_semantic_model(self, unparsed: UnparsedSemanticModel): path=path, resource_type=NodeType.SemanticModel, unique_id=unique_id, - entities=unparsed.entities, - measures=unparsed.measures, - dimensions=unparsed.dimensions, + entities=self._get_entities(unparsed.entities), + measures=self._get_measures(unparsed.measures), + dimensions=self._get_dimensions(unparsed.dimensions), + defaults=unparsed.defaults, ) self.manifest.add_semantic_model(self.yaml.file, parsed) diff --git a/tests/functional/semantic_models/test_semantic_model_parsing.py b/tests/functional/semantic_models/test_semantic_model_parsing.py index 344e58c0f61..90988d03697 100644 --- a/tests/functional/semantic_models/test_semantic_model_parsing.py +++ b/tests/functional/semantic_models/test_semantic_model_parsing.py @@ -12,6 +12,9 @@ description: This is the revenue semantic model. 
It should be able to use doc blocks model: ref('fct_revenue') + defaults: + agg_time_dimension: ds + measures: - name: txn_revenue expr: revenue @@ -22,7 +25,6 @@ type: time expr: created_at type_params: - is_primary: True time_granularity: day entities: diff --git a/tests/unit/test_semantic_layer_nodes_satisfy_protocols.py b/tests/unit/test_semantic_layer_nodes_satisfy_protocols.py new file mode 100644 index 00000000000..68a062433af --- /dev/null +++ b/tests/unit/test_semantic_layer_nodes_satisfy_protocols.py @@ -0,0 +1,123 @@ +from dbt.contracts.graph.nodes import ( + Metric, + MetricInputMeasure, + MetricTypeParams, + NodeRelation, + SemanticModel, + WhereFilter, +) +from dbt.contracts.graph.semantic_models import Dimension, DimensionTypeParams, Entity, Measure +from dbt.node_types import NodeType +from dbt_semantic_interfaces.protocols.dimension import Dimension as DSIDimension +from dbt_semantic_interfaces.protocols.entity import Entity as DSIEntitiy +from dbt_semantic_interfaces.protocols.measure import Measure as DSIMeasure +from dbt_semantic_interfaces.protocols.metric import Metric as DSIMetric +from dbt_semantic_interfaces.protocols.semantic_model import SemanticModel as DSISemanticModel +from dbt_semantic_interfaces.type_enums.dimension_type import DimensionType +from dbt_semantic_interfaces.type_enums.entity_type import EntityType +from dbt_semantic_interfaces.type_enums.metric_type import MetricType +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class RuntimeCheckableSemanticModel(DSISemanticModel, Protocol): + pass + + +@runtime_checkable +class RuntimeCheckableDimension(DSIDimension, Protocol): + pass + + +@runtime_checkable +class RuntimeCheckableEntity(DSIEntitiy, Protocol): + pass + + +@runtime_checkable +class RuntimeCheckableMeasure(DSIMeasure, Protocol): + pass + + +@runtime_checkable +class RuntimeCheckableMetric(DSIMetric, Protocol): + 
pass + + +def test_semantic_model_node_satisfies_protocol(): + test_semantic_model = SemanticModel( + name="test_semantic_model", + description="a test semantic_model", + resource_type=NodeType.SemanticModel, + package_name="package_name", + path="path.to.semantic_model", + original_file_path="path/to/file", + unique_id="not_like_the_other_semantic_models", + fqn=["fully", "qualified", "name"], + model="ref('a_model')", + node_relation=NodeRelation( + alias="test_alias", + schema_name="test_schema_name", + ), + entities=[], + measures=[], + dimensions=[], + ) + assert isinstance(test_semantic_model, RuntimeCheckableSemanticModel) + + +def test_dimension_satisfies_protocol(): + dimension = Dimension( + name="test_dimension", + description="a test dimension", + type=DimensionType.TIME, + type_params=DimensionTypeParams( + time_granularity=TimeGranularity.DAY, + ), + ) + assert isinstance(dimension, RuntimeCheckableDimension) + + +def test_entity_satisfies_protocol(): + entity = Entity( + name="test_entity", + description="a test entity", + type=EntityType.PRIMARY, + expr="id", + role="a_role", + ) + assert isinstance(entity, RuntimeCheckableEntity) + + +def test_measure_satisfies_protocol(): + measure = Measure( + name="test_measure", + description="a test measure", + agg="sum", + create_metric=True, + expr="amount", + agg_time_dimension="a_time_dimension", + ) + assert isinstance(measure, RuntimeCheckableMeasure) + + +def test_metric_node_satisfies_protocol(): + metric = Metric( + name="a_metric", + resource_type=NodeType.Metric, + package_name="package_name", + path="path.to.semantic_model", + original_file_path="path/to/file", + unique_id="not_like_the_other_semantic_models", + fqn=["fully", "qualified", "name"], + description="a test metric", + label="A test metric", + type=MetricType.SIMPLE, + type_params=MetricTypeParams( + measure=MetricInputMeasure( + name="a_test_measure", filter=WhereFilter(where_sql_template="a_dimension is true") + ) + ), + ) + assert 
isinstance(metric, RuntimeCheckableMetric)