Skip to content

Commit

Permalink
[DOP-20817] Handle case taskMetrics=null
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Oct 17, 2024
1 parent c18a4f2 commit a4ea2a6
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/313.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ``SparkMetricsRecorder`` failing when receiving ``SparkListenerTaskEnd`` without ``taskMetrics`` (e.g. executor was killed by OOM).
2 changes: 2 additions & 0 deletions onetl/_metrics/listener/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class SparkListenerTaskMetrics:

@classmethod
def create(cls, task_metrics):
if not task_metrics:
return cls()

Check warning on line 65 in onetl/_metrics/listener/task.py

View check run for this annotation

Codecov / codecov/patch

onetl/_metrics/listener/task.py#L65

Added line #L65 was not covered by tests
return cls(
executor_run_time_milliseconds=task_metrics.executorRunTime(),
executor_cpu_time_nanoseconds=task_metrics.executorCpuTime(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,62 @@ def test_spark_metrics_recorder_file_df_writer_empty_input(
metrics = recorder.metrics()
assert not metrics.output.written_rows
assert not metrics.output.written_bytes


def test_spark_metrics_recorder_file_df_writer_driver_failed(
spark,
local_fs_file_df_connection_with_path,
file_df_dataframe,
):
local_fs, target_path = local_fs_file_df_connection_with_path

df = file_df_dataframe

writer = FileDFWriter(
connection=local_fs,
format=CSV(),
target_path=target_path,
options=FileDFWriter.Options(if_exists="error"),
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows
assert not metrics.output.written_bytes


def test_spark_metrics_recorder_file_df_writer_executor_failed(
spark,
local_fs_file_df_connection_with_path,
file_df_dataframe,
):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

@udf(returnType=IntegerType())
def raise_exception():
raise ValueError("Force task failure")

local_fs, target_path = local_fs_file_df_connection_with_path

failing_df = file_df_dataframe.select(raise_exception().alias("some"))

writer = FileDFWriter(
connection=local_fs,
format=CSV(),
target_path=target_path,
options=FileDFWriter.Options(if_exists="append"),
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(failing_df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows
assert not metrics.output.written_bytes
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from contextlib import suppress

import pytest

Expand Down Expand Up @@ -137,6 +138,53 @@ def test_spark_metrics_recorder_hive_write_empty(spark, processing, get_schema_t
assert not metrics.output.written_rows


def test_spark_metrics_recorder_hive_write_driver_failed(spark, processing, prepare_schema_table):
df = processing.create_spark_df(spark).limit(0)

mismatch_df = df.withColumn("mismatch", df.id_int)

hive = Hive(cluster="rnd-dwh", spark=spark)
writer = DBWriter(
connection=hive,
target=prepare_schema_table.full_name,
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(mismatch_df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows


def test_spark_metrics_recorder_hive_write_executor_failed(spark, processing, get_schema_table):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

df = processing.create_spark_df(spark).limit(0)

@udf(returnType=IntegerType())
def raise_exception():
raise ValueError("Force task failure")

failing_df = df.select(raise_exception().alias("some"))

hive = Hive(cluster="rnd-dwh", spark=spark)
writer = DBWriter(
connection=hive,
target=get_schema_table.full_name,
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(failing_df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows


def test_spark_metrics_recorder_hive_execute(request, spark, processing, get_schema_table):
df = processing.create_spark_df(spark)
view_name = rand_str()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from contextlib import suppress

import pytest

Expand Down Expand Up @@ -167,6 +168,67 @@ def test_spark_metrics_recorder_postgres_write_empty(spark, processing, get_sche
assert not metrics.output.written_rows


def test_spark_metrics_recorder_postgres_write_driver_failed(spark, processing, prepare_schema_table):
postgres = Postgres(
host=processing.host,
port=processing.port,
user=processing.user,
password=processing.password,
database=processing.database,
spark=spark,
)
df = processing.create_spark_df(spark).limit(0)

mismatch_df = df.withColumn("mismatch", df.id_int)

writer = DBWriter(
connection=postgres,
target=prepare_schema_table.full_name,
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(mismatch_df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows


def test_spark_metrics_recorder_postgres_write_executor_failed(spark, processing, get_schema_table):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

postgres = Postgres(
host=processing.host,
port=processing.port,
user=processing.user,
password=processing.password,
database=processing.database,
spark=spark,
)

@udf(returnType=IntegerType())
def raise_exception():
raise ValueError("Force task failure")

df = processing.create_spark_df(spark).limit(0)
failing_df = df.select(raise_exception().alias("some"))

writer = DBWriter(
connection=postgres,
target=get_schema_table.full_name,
)

with SparkMetricsRecorder(spark) as recorder:
with suppress(Exception):
writer.run(failing_df)

time.sleep(0.1) # sleep to fetch late metrics from SparkListener
metrics = recorder.metrics()
assert not metrics.output.written_rows


def test_spark_metrics_recorder_postgres_fetch(spark, processing, load_table_data):
postgres = Postgres(
host=processing.host,
Expand Down

0 comments on commit a4ea2a6

Please sign in to comment.