From 0d4d470dace1f35120bb94a06cf9eefcc74cfc4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Thu, 8 Aug 2024 09:45:59 +0000 Subject: [PATCH] [DOP-18571] Collect and log Spark metrics in various method calls --- docs/changelog/next_release/303.feature.1.rst | 1 + docs/changelog/next_release/303.feature.2.rst | 10 +++ onetl/_util/spark.py | 17 ++++- onetl/base/base_db_connection.py | 3 +- onetl/base/base_file_df_connection.py | 4 +- .../db_connection/hive/connection.py | 49 +++++++++++-- .../jdbc_connection/connection.py | 8 ++- .../db_connection/jdbc_mixin/connection.py | 69 +++++++++++-------- .../db_connection/oracle/connection.py | 39 +++-------- onetl/db/db_writer/db_writer.py | 47 ++++++++++--- onetl/file/file_df_writer/file_df_writer.py | 40 +++++++++-- .../test_postgres_integration.py | 4 +- 12 files changed, 207 insertions(+), 84 deletions(-) create mode 100644 docs/changelog/next_release/303.feature.1.rst create mode 100644 docs/changelog/next_release/303.feature.2.rst diff --git a/docs/changelog/next_release/303.feature.1.rst b/docs/changelog/next_release/303.feature.1.rst new file mode 100644 index 000000000..8c0b1e19e --- /dev/null +++ b/docs/changelog/next_release/303.feature.1.rst @@ -0,0 +1 @@ +Log estimated size of in-memory dataframe created by ``JDBC.fetch`` and ``JDBC.execute`` methods. diff --git a/docs/changelog/next_release/303.feature.2.rst b/docs/changelog/next_release/303.feature.2.rst new file mode 100644 index 000000000..92bbe13c3 --- /dev/null +++ b/docs/changelog/next_release/303.feature.2.rst @@ -0,0 +1,10 @@ +Collect Spark execution metrics in the following methods, and log them in DEBUG mode: +* ``DBWriter.run()`` +* ``FileDFWriter.run()`` +* ``Hive.sql()`` +* ``Hive.execute()`` + +This is implemented using a custom ``SparkListener`` which wraps the entire method call and +then reports the collected metrics. These metrics may sometimes be missing due to Spark architecture, +so they are not a reliable source of information. That's why they are logged only in DEBUG mode and +are not returned as a method call result. diff --git a/onetl/_util/spark.py b/onetl/_util/spark.py index f172b1c98..2b2edbaf9 100644 --- a/onetl/_util/spark.py +++ b/onetl/_util/spark.py @@ -16,7 +16,7 @@ from pydantic import SecretStr # type: ignore[no-redef, assignment] if TYPE_CHECKING: - from pyspark.sql import SparkSession + from pyspark.sql import DataFrame, SparkSession from pyspark.sql.conf import RuntimeConfig @@ -136,6 +136,21 @@ def get_spark_version(spark_session: SparkSession) -> Version: return Version(spark_session.version) +def estimate_dataframe_size(spark_session: SparkSession, df: DataFrame) -> int: + """ + Estimate in-memory DataFrame size in bytes. If the size cannot be estimated, return 0. + + Using Spark's `SizeEstimator `_. + """ + try: + size_estimator = spark_session._jvm.org.apache.spark.util.SizeEstimator # type: ignore[union-attr] + return size_estimator.estimate(df._jdf) + except Exception: + # SizeEstimator uses Java reflection which may behave differently in different Java versions, + # and may also be prohibited. + return 0 + + def get_executor_total_cores(spark_session: SparkSession, include_driver: bool = False) -> tuple[int | float, dict]: """ Calculate maximum number of cores which can be used by Spark on all executors. 
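For context, a minimal usage sketch of the new ``estimate_dataframe_size`` helper, together with enabling DEBUG logging to surface the recorded metrics. The local Spark session, the ``onetl`` logger name, and the row count below are illustrative assumptions, not part of this patch:

.. code:: python

    import logging

    from pyspark.sql import SparkSession

    from onetl._util.spark import estimate_dataframe_size

    # metrics collected by SparkMetricsRecorder are logged only at DEBUG level
    logging.getLogger("onetl").setLevel(logging.DEBUG)

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.range(1_000_000)  # any DataFrame, row count is arbitrary

    # returns 0 if SizeEstimator cannot be reached via the JVM gateway or estimation fails
    size_bytes = estimate_dataframe_size(spark, df)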
diff --git a/onetl/base/base_db_connection.py b/onetl/base/base_db_connection.py index f9c7bcac0..2c427debd 100644 --- a/onetl/base/base_db_connection.py +++ b/onetl/base/base_db_connection.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from etl_entities.hwm import HWM - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import StructField, StructType @@ -106,6 +106,7 @@ class BaseDBConnection(BaseConnection): Implements generic methods for reading and writing dataframe from/to database-like source """ + spark: SparkSession Dialect = BaseDBDialect @property diff --git a/onetl/base/base_file_df_connection.py b/onetl/base/base_file_df_connection.py index c54390ce8..28c57f3c7 100644 --- a/onetl/base/base_file_df_connection.py +++ b/onetl/base/base_file_df_connection.py @@ -11,7 +11,7 @@ from onetl.base.pure_path_protocol import PurePathProtocol if TYPE_CHECKING: - from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter + from pyspark.sql import DataFrame, DataFrameReader, DataFrameWriter, SparkSession from pyspark.sql.types import StructType @@ -72,6 +72,8 @@ class BaseFileDFConnection(BaseConnection): .. versionadded:: 0.9.0 """ + spark: SparkSession + @abstractmethod def check_if_format_supported( self, diff --git a/onetl/connection/db_connection/hive/connection.py b/onetl/connection/db_connection/hive/connection.py index 81c50e87e..f057e846b 100644 --- a/onetl/connection/db_connection/hive/connection.py +++ b/onetl/connection/db_connection/hive/connection.py @@ -13,6 +13,7 @@ except (ImportError, AttributeError): from pydantic import validator # type: ignore[no-redef, assignment] +from onetl._metrics.recorder import SparkMetricsRecorder from onetl._util.spark import inject_spark_param from onetl._util.sql import clear_statement from onetl.connection.db_connection.db_connection import DBConnection @@ -210,8 +211,29 @@ def sql( log.info("|%s| Executing SQL query:", self.__class__.__name__) log_lines(log, query) - df = self._execute_sql(query) - log.info("|Spark| DataFrame successfully created from SQL statement") + with SparkMetricsRecorder(self.spark) as recorder: + try: + df = self._execute_sql(query) + except Exception: + log.error("|%s| Query failed", self.__class__.__name__) + + metrics = recorder.metrics() + if not metrics.is_empty and log.isEnabledFor(logging.DEBUG): + # as SparkListener results are not guaranteed to be received in time, + # some metrics may be missing. To avoid confusion, log only in debug, and with a notice + log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) + raise + + log.info("|Spark| DataFrame successfully created from SQL statement") + + metrics = recorder.metrics() + if not metrics.is_empty and log.isEnabledFor(logging.DEBUG): + # as SparkListener results are not guaranteed to be received in time, + # some metrics may be missing. 
To avoid confusion, log only in debug, and with a notice + log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) + return df @slot @@ -236,8 +258,27 @@ def execute( log.info("|%s| Executing statement:", self.__class__.__name__) log_lines(log, statement) - self._execute_sql(statement).collect() - log.info("|%s| Call succeeded", self.__class__.__name__) + with SparkMetricsRecorder(self.spark) as recorder: + try: + self._execute_sql(statement).collect() + except Exception: + log.error("|%s| Execution failed", self.__class__.__name__) + metrics = recorder.metrics() + if not metrics.is_empty and log.isEnabledFor(logging.DEBUG): + # as SparkListener results are not guaranteed to be received in time, + # some metrics may be missing. To avoid confusion, log only in debug, and with a notice + log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) + raise + + log.info("|%s| Execution succeeded", self.__class__.__name__) + + metrics = recorder.metrics() + if not metrics.is_empty and log.isEnabledFor(logging.DEBUG): + # as SparkListener results are not guaranteed to be received in time, + # some metrics may be missing. To avoid confusion, log only in debug, and with a notice + log.info("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) @slot def write_df_to_target( diff --git a/onetl/connection/db_connection/jdbc_connection/connection.py b/onetl/connection/db_connection/jdbc_connection/connection.py index 5b0aebeb8..2fc2f7cfa 100644 --- a/onetl/connection/db_connection/jdbc_connection/connection.py +++ b/onetl/connection/db_connection/jdbc_connection/connection.py @@ -92,9 +92,13 @@ def sql( log.info("|%s| Executing SQL query (on executor):", self.__class__.__name__) log_lines(log, query) - df = self._query_on_executor(query, self.SQLOptions.parse(options)) + try: + df = self._query_on_executor(query, self.SQLOptions.parse(options)) + except Exception: + log.error("|%s| Query failed!", self.__class__.__name__) + raise - log.info("|Spark| DataFrame successfully created from SQL statement ") + log.info("|Spark| DataFrame successfully created from SQL statement") return df @slot diff --git a/onetl/connection/db_connection/jdbc_mixin/connection.py b/onetl/connection/db_connection/jdbc_mixin/connection.py index e8c19e38b..84276147a 100644 --- a/onetl/connection/db_connection/jdbc_mixin/connection.py +++ b/onetl/connection/db_connection/jdbc_mixin/connection.py @@ -9,15 +9,14 @@ from enum import Enum, auto from typing import TYPE_CHECKING, Callable, ClassVar, Optional, TypeVar -from onetl.impl.generic_options import GenericOptions - try: from pydantic.v1 import Field, PrivateAttr, SecretStr, validator except (ImportError, AttributeError): from pydantic import Field, PrivateAttr, SecretStr, validator # type: ignore[no-redef, assignment] +from onetl._metrics.command import SparkCommandMetrics from onetl._util.java import get_java_gateway, try_import_java_class -from onetl._util.spark import get_spark_version, stringify +from onetl._util.spark import estimate_dataframe_size, get_spark_version, stringify from onetl._util.sql import clear_statement from onetl._util.version import Version from onetl.connection.db_connection.jdbc_mixin.options import ( @@ -29,7 +28,7 @@ ) from onetl.exception import MISSING_JVM_CLASS_MSG from onetl.hooks import slot, 
support_hooks -from onetl.impl import FrozenModel +from onetl.impl import FrozenModel, GenericOptions from onetl.log import log_lines if TYPE_CHECKING: @@ -204,20 +203,27 @@ def fetch( log.info("|%s| Executing SQL query (on driver):", self.__class__.__name__) log_lines(log, query) - df = self._query_on_driver( - query, - ( - self.FetchOptions.parse(options.dict()) # type: ignore - if isinstance(options, JDBCMixinOptions) - else self.FetchOptions.parse(options) - ), + call_options = ( + self.FetchOptions.parse(options.dict()) # type: ignore + if isinstance(options, JDBCMixinOptions) + else self.FetchOptions.parse(options) ) - log.info( - "|%s| Query succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - df.count(), - ) + try: + df = self._query_on_driver(query, call_options) + except Exception: + log.error("|%s| Query failed!", self.__class__.__name__) + raise + + log.info("|%s| Query succeeded, created in-memory dataframe.", self.__class__.__name__) + + # as we don't actually use Spark for this method, SparkMetricsRecorder is useless. + # Just create metrics by hand, and fill them up using information based on dataframe content. + metrics = SparkCommandMetrics() + metrics.input.read_rows = df.count() + metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df) + log.info("|%s| Recorded metrics:", self.__class__.__name__) + log_lines(log, str(metrics)) return df @slot @@ -273,17 +279,26 @@ def execute( if isinstance(options, JDBCMixinOptions) else self.ExecuteOptions.parse(options) ) - df = self._call_on_driver(statement, call_options) - - if df is not None: - rows_count = df.count() - log.info( - "|%s| Execution succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - rows_count, - ) - else: - log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) + + try: + df = self._call_on_driver(statement, call_options) + except Exception: + log.error("|%s| Execution failed!", self.__class__.__name__) + raise + + if not df: + log.info("|%s| Execution succeeded, nothing returned.", self.__class__.__name__) + return None + + log.info("|%s| Execution succeeded, created in-memory dataframe.", self.__class__.__name__) + # as we don't actually use Spark for this method, SparkMetricsRecorder is useless. + # Just create metrics by hand, and fill them up using information based on dataframe content. 
+ metrics = SparkCommandMetrics() + metrics.input.read_rows = df.count() + metrics.driver.in_memory_bytes = estimate_dataframe_size(self.spark, df) + + log.info("|%s| Recorded metrics:", self.__class__.__name__) + log_lines(log, str(metrics)) return df @validator("spark") diff --git a/onetl/connection/db_connection/oracle/connection.py b/onetl/connection/db_connection/oracle/connection.py index 043989500..c76693613 100644 --- a/onetl/connection/db_connection/oracle/connection.py +++ b/onetl/connection/db_connection/oracle/connection.py @@ -20,14 +20,12 @@ from etl_entities.instance import Host from onetl._util.classproperty import classproperty -from onetl._util.sql import clear_statement from onetl._util.version import Version from onetl.connection.db_connection.jdbc_connection import JDBCConnection from onetl.connection.db_connection.jdbc_connection.options import JDBCReadOptions from onetl.connection.db_connection.jdbc_mixin.options import ( JDBCExecuteOptions, JDBCFetchOptions, - JDBCOptions, ) from onetl.connection.db_connection.oracle.dialect import OracleDialect from onetl.connection.db_connection.oracle.options import ( @@ -43,8 +41,6 @@ from onetl.log import BASE_LOG_INDENT, log_lines # do not import PySpark here, as we allow user to use `Oracle.get_packages()` for creating Spark session - - if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -290,32 +286,6 @@ def get_min_max_values( max_value = int(max_value) return min_value, max_value - @slot - def execute( - self, - statement: str, - options: JDBCOptions | JDBCExecuteOptions | dict | None = None, # noqa: WPS437 - ) -> DataFrame | None: - statement = clear_statement(statement) - - log.info("|%s| Executing statement (on driver):", self.__class__.__name__) - log_lines(log, statement) - - call_options = self.ExecuteOptions.parse(options) - df = self._call_on_driver(statement, call_options) - self._handle_compile_errors(statement.strip(), call_options) - - if df is not None: - rows_count = df.count() - log.info( - "|%s| Execution succeeded, resulting in-memory dataframe contains %d rows", - self.__class__.__name__, - rows_count, - ) - else: - log.info("|%s| Execution succeeded, nothing returned", self.__class__.__name__) - return df - @root_validator def _only_one_of_sid_or_service_name(cls, values): sid = values.get("sid") @@ -329,6 +299,15 @@ def _only_one_of_sid_or_service_name(cls, values): return values + def _call_on_driver( + self, + query: str, + options: JDBCExecuteOptions, + ) -> DataFrame | None: + result = super()._call_on_driver(query, options) + self._handle_compile_errors(query.strip(), options) + return result + def _parse_create_statement(self, statement: str) -> tuple[str, str, str] | None: """ Parses ``CREATE ... 
type_name [schema.]object_name ...`` statement diff --git a/onetl/db/db_writer/db_writer.py b/onetl/db/db_writer/db_writer.py index 666fce87e..06dbd44c5 100644 --- a/onetl/db/db_writer/db_writer.py +++ b/onetl/db/db_writer/db_writer.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from logging import getLogger +import logging from typing import TYPE_CHECKING, Optional try: @@ -10,12 +10,15 @@ except (ImportError, AttributeError): from pydantic import Field, PrivateAttr, validator # type: ignore[no-redef, assignment] +from onetl._metrics.command import SparkCommandMetrics +from onetl._metrics.recorder import SparkMetricsRecorder from onetl.base import BaseDBConnection from onetl.hooks import slot, support_hooks from onetl.impl import FrozenModel, GenericOptions from onetl.log import ( entity_boundary_log, log_dataframe_schema, + log_lines, log_options, log_with_indent, ) @@ -23,7 +26,7 @@ if TYPE_CHECKING: from pyspark.sql import DataFrame -log = getLogger(__name__) +log = logging.getLogger(__name__) @support_hooks @@ -172,7 +175,7 @@ def validate_options(cls, options, values): return None @slot - def run(self, df: DataFrame): + def run(self, df: DataFrame) -> None: """ Method for writing your df to specified target. |support_hooks| @@ -188,7 +191,7 @@ def run(self, df: DataFrame): Examples -------- - Write df to target: + Write dataframe to target: .. code:: python @@ -198,18 +201,37 @@ def run(self, df: DataFrame): raise ValueError(f"DataFrame is streaming. {self.__class__.__name__} supports only batch DataFrames.") entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() starts") - if not self._connection_checked: self._log_parameters() log_dataframe_schema(log, df) self.connection.check() self._connection_checked = True - self.connection.write_df_to_target( - df=df, - target=str(self.target), - **self._get_write_kwargs(), - ) + with SparkMetricsRecorder(self.connection.spark) as recorder: + try: + self.connection.write_df_to_target( + df=df, + target=str(self.target), + **self._get_write_kwargs(), + ) + except Exception: + metrics = recorder.metrics() + # SparkListener is not a reliable source of information, metrics may or may not be present. + # Because of this we also do not return these metrics as method result + if metrics.output.is_empty: + log.error( + "|%s| Error while writing dataframe.", + self.__class__.__name__, + ) + else: + log.error( + "|%s| Error while writing dataframe. 
Target MAY contain partially written data!", + self.__class__.__name__, + ) + self._log_metrics(metrics) + raise + finally: + self._log_metrics(recorder.metrics()) entity_boundary_log(log, msg=f"{self.__class__.__name__}.run() ends", char="-") @@ -225,3 +247,8 @@ def _get_write_kwargs(self) -> dict: return {"options": self.options} return {} + + def _log_metrics(self, metrics: SparkCommandMetrics) -> None: + if not metrics.is_empty: + log.debug("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) diff --git a/onetl/file/file_df_writer/file_df_writer.py b/onetl/file/file_df_writer/file_df_writer.py index a80f54801..6431219af 100644 --- a/onetl/file/file_df_writer/file_df_writer.py +++ b/onetl/file/file_df_writer/file_df_writer.py @@ -10,6 +10,8 @@ except (ImportError, AttributeError): from pydantic import PrivateAttr, validator # type: ignore[no-redef, assignment] +from onetl._metrics.command import SparkCommandMetrics +from onetl._metrics.recorder import SparkMetricsRecorder from onetl.base import BaseFileDFConnection, BaseWritableFileFormat, PurePathProtocol from onetl.file.file_df_writer.options import FileDFWriterOptions from onetl.hooks import slot, support_hooks @@ -17,6 +19,7 @@ from onetl.log import ( entity_boundary_log, log_dataframe_schema, + log_lines, log_options, log_with_indent, ) @@ -125,12 +128,32 @@ def run(self, df: DataFrame) -> None: self.connection.check() self._connection_checked = True - self.connection.write_df_as_files( - df=df, - path=self.target_path, - format=self.format, - options=self.options, - ) + with SparkMetricsRecorder(self.connection.spark) as recorder: + try: + self.connection.write_df_as_files( + df=df, + path=self.target_path, + format=self.format, + options=self.options, + ) + except Exception: + metrics = recorder.metrics() + if metrics.output.is_empty: + # SparkListener is not a reliable source of information, metrics may or may not be present. + # Because of this we also do not return these metrics as method result + log.error( + "|%s| Error while writing dataframe.", + self.__class__.__name__, + ) + else: + log.error( + "|%s| Error while writing dataframe. 
Target MAY contain partially written data!", + self.__class__.__name__, + ) + self._log_metrics(metrics) + raise + finally: + self._log_metrics(recorder.metrics()) entity_boundary_log(log, f"{self.__class__.__name__}.run() ends", char="-") @@ -143,6 +166,11 @@ def _log_parameters(self, df: DataFrame) -> None: log_options(log, options_dict) log_dataframe_schema(log, df) + def _log_metrics(self, metrics: SparkCommandMetrics) -> None: + if not metrics.is_empty: + log.debug("|%s| Recorded metrics (some values may be missing!):", self.__class__.__name__) + log_lines(log, str(metrics), level=logging.DEBUG) + @validator("target_path", pre=True) def _validate_target_path(cls, target_path, values): connection: BaseFileDFConnection = values["connection"] diff --git a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py index b72f8ac1b..6cea95cca 100644 --- a/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py +++ b/tests/tests_integration/tests_db_connection_integration/test_postgres_integration.py @@ -1007,7 +1007,7 @@ def test_postgres_connection_sql_options( processing.assert_equal_df(df=df, other_frame=table_df) -def test_postgres_fetch_with_legacy_jdbc_options(spark, processing): +def test_postgres_connection_fetch_with_legacy_jdbc_options(spark, processing): postgres = Postgres( host=processing.host, port=processing.port, @@ -1023,7 +1023,7 @@ def test_postgres_fetch_with_legacy_jdbc_options(spark, processing): assert df is not None -def test_postgres_execute_with_legacy_jdbc_options(spark, processing): +def test_postgres_connection_execute_with_legacy_jdbc_options(spark, processing): postgres = Postgres( host=processing.host, port=processing.port,
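As a closing note: the write paths touched by this patch (``DBWriter.run()``, ``FileDFWriter.run()``, ``Hive.sql()``, ``Hive.execute()``) all follow the same pattern: wrap the Spark action in ``SparkMetricsRecorder``, log the collected metrics at DEBUG level, and warn about possibly partially written data on failure. Below is a condensed, hedged sketch of that pattern; ``do_write`` is a hypothetical placeholder for any of the wrapped calls and is not part of this patch:

.. code:: python

    import logging

    from onetl._metrics.recorder import SparkMetricsRecorder

    log = logging.getLogger(__name__)


    def run_with_metrics(spark, do_write):
        # `do_write` is a hypothetical callable performing the actual Spark action
        with SparkMetricsRecorder(spark) as recorder:
            try:
                do_write()
            except Exception:
                log.error("Error while writing dataframe. Target MAY contain partially written data!")
                raise
            finally:
                metrics = recorder.metrics()
                if not metrics.is_empty:
                    # SparkListener callbacks are not guaranteed to arrive in time,
                    # so some values may be missing
                    log.debug("Recorded metrics (some values may be missing!): %s", metrics)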