Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore type checking for functions decorated with @lru_cache #1596

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from __future__ import annotations

import threading
from typing import Dict, Generic, Optional, TypeVar
from typing import TYPE_CHECKING, Callable, Dict, Generic, Optional, TypeVar

if TYPE_CHECKING:
    # Type checkers erase parameter types on functions wrapped with
    # `functools.lru_cache` (the wrapper's signature is opaque). Presenting an
    # identity-typed stub to the type checker keeps the decorated function's
    # original signature visible. This branch never executes at runtime.
    F = TypeVar("F", bound=Callable)

    def typed_lru_cache(f: F) -> F:  # noqa: D103
        return f

else:
    # At runtime `typed_lru_cache` is simply `functools.lru_cache`.
    from functools import lru_cache as typed_lru_cache  # noqa: F401

KeyT = TypeVar("KeyT")
ValueT = TypeVar("ValueT")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import logging
from functools import lru_cache
from typing import Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow_semantics.collection_helpers.lru_cache import typed_lru_cache

DUNDER = "__"

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,7 +37,7 @@ def __init__(
self.date_part = date_part

@staticmethod
@lru_cache
@typed_lru_cache
def from_name(qualified_name: str, custom_granularity_names: Sequence[str]) -> StructuredLinkableSpecName:
"""Construct from a name e.g. listing__ds__month."""
name_parts = qualified_name.split(DUNDER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
Expand All @@ -11,6 +10,7 @@
from typing_extensions import override

from metricflow_semantics.aggregation_properties import AggregationState
from metricflow_semantics.collection_helpers.lru_cache import typed_lru_cache
from metricflow_semantics.model.semantics.linkable_element import ElementPathKey, LinkableElementType
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.specs.dimension_spec import DimensionSpec
Expand Down Expand Up @@ -213,7 +213,7 @@ def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ())
)

@classmethod
@lru_cache
@typed_lru_cache
def _get_compatible_grain_and_date_part(cls) -> Sequence[Tuple[ExpandedTimeGranularity, DatePart]]:
items = []
for date_part in DatePart:
Expand Down
12 changes: 7 additions & 5 deletions metricflow-semantics/metricflow_semantics/time/granularity.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property, lru_cache
from functools import cached_property
from typing import FrozenSet

from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow_semantics.collection_helpers.lru_cache import typed_lru_cache


@dataclass(frozen=True)
class ExpandedTimeGranularity(SerializableDataclass):
Expand Down Expand Up @@ -39,18 +41,18 @@ def is_custom_granularity(self) -> bool: # noqa: D102
return self.base_granularity.value != self.name

@classmethod
@typed_lru_cache
def from_time_granularity(cls, granularity: TimeGranularity) -> ExpandedTimeGranularity:
    """Build the `ExpandedTimeGranularity` equivalent of a standard `TimeGranularity` value.

    The granularity's enum value doubles as the name of the expanded granularity.
    Results are memoized; `@typed_lru_cache` is appropriate here since the number
    of `TimeGranularity` members is small.
    """
    granularity_name = granularity.value
    return ExpandedTimeGranularity(name=granularity_name, base_granularity=granularity)

@classmethod
@typed_lru_cache
def _standard_time_granularity_names(cls) -> FrozenSet[str]:
    """Return the names of all standard (non-custom) time granularities.

    Fix: the return annotation was a bare `FrozenSet`, which tells the type
    checker nothing about the element type; `FrozenSet[str]` documents that the
    members are granularity name strings (each `TimeGranularity.value`).

    This should be appropriate to use with `@typed_lru_cache` since the number of `TimeGranularity` is small.
    """
    return frozenset(granularity.value for granularity in TimeGranularity)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import logging
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Optional, Sequence

from dbt_semantic_interfaces.implementations.time_spine import PydanticTimeSpineCustomGranularityColumn
from dbt_semantic_interfaces.protocols import SemanticManifest
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow_semantics.collection_helpers.lru_cache import typed_lru_cache
from metricflow_semantics.specs.time_dimension_spec import DEFAULT_TIME_GRANULARITY, TimeDimensionSpec
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
Expand Down Expand Up @@ -78,7 +78,7 @@ def build_standard_time_spine_sources(
return time_spine_sources

@staticmethod
@lru_cache
@typed_lru_cache
def build_custom_time_spine_sources(time_spine_sources: Sequence[TimeSpineSource]) -> Dict[str, TimeSpineSource]:
"""Creates a set of time spine sources with custom granularities based on what's in the manifest."""
return {
Expand Down
Loading