Merge branch 'CLOUD-8218/spark_df_failed_rows_100_limit' of github.com:sodadata/soda-core into CLOUD-8218/spark_df_failed_rows_100_limit
jzalucki committed Aug 1, 2024
2 parents b86f16e + 4731f2f commit 67cd941
Showing 5 changed files with 58 additions and 49 deletions.
83 changes: 45 additions & 38 deletions soda/core/soda/execution/query/query.py
@@ -120,6 +120,16 @@ def execute(self):
         # TODO: some of the subclasses couple setting metric with storing the sample - refactor that.
         self.fetchall()
 
+    def _cursor_execute_exception_handler(self, e):
+        data_source = self.data_source_scan.data_source
+        self.exception = e
+        self.logs.error(
+            message=f"Query execution error in {self.query_name}: {e}\n{self.sql}",
+            exception=e,
+            location=self.location,
+        )
+        data_source.query_failed(e)
+
     def _execute_cursor(self, execute=True):
         """
         Execute the SQL query and yield the cursor for further processing.
@@ -143,13 +153,7 @@ def _execute_cursor(self, execute=True):
                     cursor.reset()
                 cursor.close()
         except BaseException as e:
-            self.exception = e
-            self.logs.error(
-                message=f"Query execution error in {self.query_name}: {e}\n{self.sql}",
-                exception=e,
-                location=self.location,
-            )
-            data_source.query_failed(e)
+            self._cursor_execute_exception_handler(e)
         finally:
             self.duration = datetime.now() - start
 
@@ -199,37 +203,40 @@ def store(self):
         if hasattr(self, "metric") and self.metric and self.metric.value == undefined:
             set_metric = True
 
-        if set_metric or allow_samples:
-            self.logs.debug(f"Query {self.query_name}:\n{self.sql}")
-            cursor.execute(str(self.sql))
-            self.description = cursor.description
-            db_sample = DbSample(cursor, self.data_source_scan.data_source, self.samples_limit)
-
-            if set_metric:
-                self.metric.set_value(db_sample.get_rows_count())
-
-            if allow_samples:
-                # TODO Hacky way to get the check name, check name isn't there when dataset samples are taken
-                check_name = next(iter(self.metric.checks)).name if hasattr(self, "metric") else None
-                sample_context = SampleContext(
-                    sample=db_sample,
-                    sample_name=self.sample_name,
-                    query=self.sql,
-                    data_source=self.data_source_scan.data_source,
-                    partition=self.partition,
-                    column=self.column,
-                    scan=self.data_source_scan.scan,
-                    logs=self.data_source_scan.scan._logs,
-                    samples_limit=self.samples_limit,
-                    passing_sql=self.passing_sql,
-                    check_name=check_name,
-                )
-
-                self.sample_ref = sampler.store_sample(sample_context)
-            else:
-                self.logs.info(
-                    f"Skipping samples from query '{self.query_name}'. Excluded column(s) present: {offending_columns}."
-                )
+        try:
+            if set_metric or allow_samples:
+                self.logs.debug(f"Query {self.query_name}:\n{self.sql}")
+                cursor.execute(str(self.sql))
+                self.description = cursor.description
+                db_sample = DbSample(cursor, self.data_source_scan.data_source, self.samples_limit)
+
+                if set_metric:
+                    self.metric.set_value(db_sample.get_rows_count())
+
+                if allow_samples:
+                    # TODO Hacky way to get the check name, check name isn't there when dataset samples are taken
+                    check_name = next(iter(self.metric.checks)).name if hasattr(self, "metric") else None
+                    sample_context = SampleContext(
+                        sample=db_sample,
+                        sample_name=self.sample_name,
+                        query=self.sql,
+                        data_source=self.data_source_scan.data_source,
+                        partition=self.partition,
+                        column=self.column,
+                        scan=self.data_source_scan.scan,
+                        logs=self.data_source_scan.scan._logs,
+                        samples_limit=self.samples_limit,
+                        passing_sql=self.passing_sql,
+                        check_name=check_name,
+                    )
+
+                    self.sample_ref = sampler.store_sample(sample_context)
+                else:
+                    self.logs.info(
+                        f"Skipping samples from query '{self.query_name}'. Excluded column(s) present: {offending_columns}."
+                    )
+        except BaseException as e:
+            self._cursor_execute_exception_handler(e)
 
     def __append_to_scan(self):
         scan = self.data_source_scan.scan
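
Taken together, the query.py changes extract the duplicated except-block into a `_cursor_execute_exception_handler` method and route the sampling path in `store()` through that same handler, so a failing sample query is now logged and reported to the data source instead of escaping. Below is a minimal, self-contained sketch of that pattern; `FakeLogs`, `FakeDataSource`, and `MiniQuery` are hypothetical stand-ins for illustration, not soda-core's real classes:

from datetime import datetime


class FakeLogs:
    # Hypothetical stand-in for soda-core's Logs; only what the sketch needs.
    def error(self, message):
        print(f"ERROR: {message}")


class FakeDataSource:
    # Hypothetical stand-in for a soda-core DataSource.
    def query_failed(self, e):
        print(f"data source notified of failure: {e}")


class MiniQuery:
    def __init__(self, query_name, sql, logs, data_source):
        self.query_name = query_name
        self.sql = sql
        self.logs = logs
        self.data_source = data_source
        self.exception = None
        self.duration = None

    def _cursor_execute_exception_handler(self, e):
        # One place records the exception, logs it, and notifies the data source.
        self.exception = e
        self.logs.error(f"Query execution error in {self.query_name}: {e}\n{self.sql}")
        self.data_source.query_failed(e)

    def execute(self, cursor):
        start = datetime.now()
        try:
            cursor.execute(self.sql)
        except BaseException as e:
            self._cursor_execute_exception_handler(e)
        finally:
            self.duration = datetime.now() - start

    def store(self, cursor):
        # Before this commit, an error raised here bypassed the handling that
        # execute() had; now both paths share the same handler.
        try:
            cursor.execute(self.sql)
            return cursor.fetchmany(100)
        except BaseException as e:
            self._cursor_execute_exception_handler(e)
            return None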
@@ -1,4 +1,4 @@
-from helpers.common_test_tables import customers_test_table, customers_huge_test_table
+from helpers.common_test_tables import customers_huge_test_table, customers_test_table
 from helpers.data_source_fixture import DataSourceFixture
 from helpers.mock_http_request import MockHttpRequest
 from helpers.mock_http_sampler import MockHttpSampler
18 changes: 10 additions & 8 deletions soda/core/tests/helpers/common_test_tables.py
@@ -1,14 +1,15 @@
 import os
-from faker import Faker
-from random import randint, uniform, choice
+from datetime import date, datetime, timedelta, timezone
 from decimal import Decimal
-from datetime import date, datetime, timezone, timedelta
+from random import choice, randint, uniform
+
+from faker import Faker
 from helpers.test_table import TestTable
 from soda.execution.data_type import DataType
 
 utc = timezone.utc
 
 
 def generate_customer_records(num_records):
     fake = Faker()
     start_date = datetime(2020, 6, 24, 0, 4, 10)
@@ -25,18 +26,19 @@ def generate_customer_records(num_records):
             fake.word(),  # cst_size_txt
             randint(1, 10000),  # distance
             f"{randint(0, 100)}%",  # pct
-            choice(['A', 'B', 'C', 'D']),  # cat
-            choice(['PL', 'BE', 'NL', 'US']),  # country
+            choice(["A", "B", "C", "D"]),  # cat
+            choice(["PL", "BE", "NL", "US"]),  # country
             fake.zipcode(),  # zip
             fake.email(),  # email
             fake.date_this_century(),  # date_updated
             current_date,  # ts
-            current_date_with_tz  # ts_with_tz
+            current_date_with_tz,  # ts_with_tz
         )
         records.append(record)
 
     return records
 
+
 customers_test_table = TestTable(
     name="Customers",
     create_view=os.getenv("TEST_WITH_VIEWS", False),
@@ -89,7 +91,7 @@ def generate_customer_records(num_records):
         ("ts", DataType.TIMESTAMP),
         ("ts_with_tz", DataType.TIMESTAMP_TZ),
     ],
-    values=generate_customer_records(120)
+    values=generate_customer_records(120),
 )
 
 customers_dist_check_test_table = TestTable(
2 changes: 1 addition & 1 deletion soda/core/tests/helpers/mock_soda_cloud.py
@@ -152,7 +152,7 @@ def find_check(self, check_index: int) -> dict | None:
         checks = scan_result["checks"]
         assert len(checks) > check_index
         return checks[check_index]
-
+
     def find_check_metric(self, metric_name: str) -> dict | None:
         assert len(self.scan_results) > 0
         scan_result = self.scan_results[0]
2 changes: 1 addition & 1 deletion soda/spark/soda/data_sources/spark_data_source.py
@@ -219,7 +219,7 @@ def get_table_columns(
             ),
         )
         query.execute()
-        if len(query.rows) > 0:
+        if query.rows and len(query.rows) > 0:
             rows = query.rows
             # Remove the partitioning information (see https://spark.apache.org/docs/latest/sql-ref-syntax-aux-describe-table.html)
             partition_indices = [i for i in range(len(rows)) if rows[i][0].startswith("# Partition")]
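
The Spark fix guards against `query.rows` being `None`, which is what it remains when the query never produced a result set (for example, after a failure); `len(None)` raises a `TypeError`, whereas an empty list merely skips the block. A small illustration of why the truthiness check matters; `first_column` is a hypothetical helper for this sketch, not the soda-core API:

from typing import Optional


def first_column(rows: Optional[list[tuple]]) -> list:
    # rows is None  -> query never produced results (e.g. it failed)
    # rows == []    -> query ran but matched nothing
    # A bare len(rows) > 0 raises TypeError on None; the truthiness check
    # short-circuits before len() is ever called.
    if rows and len(rows) > 0:
        return [row[0] for row in rows]
    return []


assert first_column(None) == []  # would crash with a bare len(rows) check
assert first_column([]) == []
assert first_column([("id", "int"), ("name", "string")]) == ["id", "name"]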
