Skip to content

Commit

Permalink
Adap 1049/lazy load agate (#1050)
Browse files Browse the repository at this point in the history
* Add changelog

* Lazy load agate.

* More comments on types and lint.

---------

Co-authored-by: Mila Page <versusfacit@users.noreply.github.com>
  • Loading branch information
VersusFacit and VersusFacit authored Jun 14, 2024
1 parent 944dbea commit 7850da3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240612-195629.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Lazy load agate to improve performance
time: 2024-06-12T19:56:29.943204-07:00
custom:
Author: versusfacit
Issue: "1049"
39 changes: 24 additions & 15 deletions dbt/adapters/spark/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Callable,
Set,
FrozenSet,
TYPE_CHECKING,
)

from dbt.adapters.base.relation import InformationSchema
Expand All @@ -24,7 +25,10 @@

from typing_extensions import TypeAlias

import agate
if TYPE_CHECKING:
# Indirectly imported via agate_helper, which is lazy loaded further downfile.
# Used by mypy for earlier type hints.
import agate

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport
Expand Down Expand Up @@ -127,34 +131,36 @@ def date_function(cls) -> str:
return "current_timestamp()"

@classmethod
def convert_text_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_text_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "string"

@classmethod
def convert_number_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_number_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
import agate

decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
return "double" if decimals else "bigint"

@classmethod
def convert_integer_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_integer_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "bigint"

@classmethod
def convert_date_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_date_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "date"

@classmethod
def convert_time_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_time_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "time"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: "agate.Table", col_idx: int) -> str:
return "timestamp"

def quote(self, identifier: str) -> str:
return "`{}`".format(identifier)

def _get_relation_information(self, row: agate.Row) -> RelationInfo:
def _get_relation_information(self, row: "agate.Row") -> RelationInfo:
"""relation info was fetched with SHOW TABLES EXTENDED"""
try:
_schema, name, _, information = row
Expand All @@ -165,7 +171,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo:

return _schema, name, information

def _get_relation_information_using_describe(self, row: agate.Row) -> RelationInfo:
def _get_relation_information_using_describe(self, row: "agate.Row") -> RelationInfo:
"""Relation info fetched using SHOW TABLES and an auxiliary DESCRIBE statement"""
try:
_schema, name, _ = row
Expand Down Expand Up @@ -193,8 +199,8 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn

def _build_spark_relation_list(
self,
row_list: agate.Table,
relation_info_func: Callable[[agate.Row], RelationInfo],
row_list: "agate.Table",
relation_info_func: Callable[["agate.Row"], RelationInfo],
) -> List[BaseRelation]:
"""Aggregate relations with format metadata included."""
relations = []
Expand Down Expand Up @@ -370,15 +376,15 @@ def get_catalog(
self,
relation_configs: Iterable[RelationConfig],
used_schemas: FrozenSet[Tuple[str, str]],
) -> Tuple[agate.Table, List[Exception]]:
) -> Tuple["agate.Table", List[Exception]]:
schema_map = self._get_catalog_schemas(relation_configs)
if len(schema_map) > 1:
raise CompilationError(
f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
)

with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = []
futures: List[Future["agate.Table"]] = []
for info, schemas in schema_map.items():
for schema in schemas:
futures.append(
Expand All @@ -399,7 +405,7 @@ def _get_one_catalog(
information_schema: InformationSchema,
schemas: Set[str],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
) -> "agate.Table":
if len(schemas) != 1:
raise CompilationError(
f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}"
Expand All @@ -412,6 +418,9 @@ def _get_one_catalog(
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", str(relation))
columns.extend(self._get_columns_for_catalog(relation))

import agate

return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database: str, schema: str) -> bool:
Expand Down Expand Up @@ -486,7 +495,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
"all_purpose_cluster": AllPurposeClusterPythonJobHelper,
}

def standardize_grants_dict(self, grants_table: agate.Table) -> dict:
def standardize_grants_dict(self, grants_table: "agate.Table") -> dict:
grants_dict: Dict[str, List[str]] = {}
for row in grants_table:
grantee = row["Principal"]
Expand Down

0 comments on commit 7850da3

Please sign in to comment.