
Commit

Merge branch 'main' into CLOUD-8218/spark_df_failed_rows_100_limit
jzalucki authored Aug 1, 2024
2 parents 126a32b + 8741aac commit 4731f2f
Showing 2 changed files with 46 additions and 39 deletions.
83 changes: 45 additions & 38 deletions soda/core/soda/execution/query/query.py
@@ -120,6 +120,16 @@ def execute(self):
         # TODO: some of the subclasses couple setting metric with storing the sample - refactor that.
         self.fetchall()
 
+    def _cursor_execute_exception_handler(self, e):
+        data_source = self.data_source_scan.data_source
+        self.exception = e
+        self.logs.error(
+            message=f"Query execution error in {self.query_name}: {e}\n{self.sql}",
+            exception=e,
+            location=self.location,
+        )
+        data_source.query_failed(e)
+
     def _execute_cursor(self, execute=True):
         """
         Execute the SQL query and yield the cursor for further processing.
@@ -143,13 +153,7 @@ def _execute_cursor(self, execute=True):
                 cursor.reset()
                 cursor.close()
         except BaseException as e:
-            self.exception = e
-            self.logs.error(
-                message=f"Query execution error in {self.query_name}: {e}\n{self.sql}",
-                exception=e,
-                location=self.location,
-            )
-            data_source.query_failed(e)
+            self._cursor_execute_exception_handler(e)
         finally:
            self.duration = datetime.now() - start
 
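The two hunks above extract the inline error handling of _execute_cursor into a reusable _cursor_execute_exception_handler, so every execution path records and logs a failed query the same way. As a rough illustration of that pattern only, here is a minimal, self-contained sketch; the class, attribute, and method names below are made up and are not Soda's API:

# Illustrative sketch only: a toy query object with one shared error handler
# used by two execution paths (not Soda's real classes).
class ToyQuery:
    def __init__(self, name: str, sql: str):
        self.name = name
        self.sql = sql
        self.exception = None
        self.rows = None

    def _handle_execution_error(self, e: BaseException) -> None:
        # Single place that records and reports a failed execution.
        self.exception = e
        print(f"Query execution error in {self.name}: {e}\n{self.sql}")

    def execute(self, cursor) -> None:
        try:
            cursor.execute(self.sql)
            self.rows = cursor.fetchall()
        except BaseException as e:
            self._handle_execution_error(e)

    def store(self, cursor) -> None:
        try:
            cursor.execute(self.sql)
            self.rows = cursor.fetchall()
        except BaseException as e:
            # Reuses the shared handler instead of duplicating the logging block.
            self._handle_execution_error(e)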
@@ -199,37 +203,40 @@ def store(self):
         if hasattr(self, "metric") and self.metric and self.metric.value == undefined:
             set_metric = True
 
-        if set_metric or allow_samples:
-            self.logs.debug(f"Query {self.query_name}:\n{self.sql}")
-            cursor.execute(str(self.sql))
-            self.description = cursor.description
-            db_sample = DbSample(cursor, self.data_source_scan.data_source, self.samples_limit)
-
-            if set_metric:
-                self.metric.set_value(db_sample.get_rows_count())
-
-            if allow_samples:
-                # TODO Hacky way to get the check name, check name isn't there when dataset samples are taken
-                check_name = next(iter(self.metric.checks)).name if hasattr(self, "metric") else None
-                sample_context = SampleContext(
-                    sample=db_sample,
-                    sample_name=self.sample_name,
-                    query=self.sql,
-                    data_source=self.data_source_scan.data_source,
-                    partition=self.partition,
-                    column=self.column,
-                    scan=self.data_source_scan.scan,
-                    logs=self.data_source_scan.scan._logs,
-                    samples_limit=self.samples_limit,
-                    passing_sql=self.passing_sql,
-                    check_name=check_name,
-                )
-
-                self.sample_ref = sampler.store_sample(sample_context)
-        else:
-            self.logs.info(
-                f"Skipping samples from query '{self.query_name}'. Excluded column(s) present: {offending_columns}."
-            )
+        try:
+            if set_metric or allow_samples:
+                self.logs.debug(f"Query {self.query_name}:\n{self.sql}")
+                cursor.execute(str(self.sql))
+                self.description = cursor.description
+                db_sample = DbSample(cursor, self.data_source_scan.data_source, self.samples_limit)
+
+                if set_metric:
+                    self.metric.set_value(db_sample.get_rows_count())
+
+                if allow_samples:
+                    # TODO Hacky way to get the check name, check name isn't there when dataset samples are taken
+                    check_name = next(iter(self.metric.checks)).name if hasattr(self, "metric") else None
+                    sample_context = SampleContext(
+                        sample=db_sample,
+                        sample_name=self.sample_name,
+                        query=self.sql,
+                        data_source=self.data_source_scan.data_source,
+                        partition=self.partition,
+                        column=self.column,
+                        scan=self.data_source_scan.scan,
+                        logs=self.data_source_scan.scan._logs,
+                        samples_limit=self.samples_limit,
+                        passing_sql=self.passing_sql,
+                        check_name=check_name,
+                    )
+
+                    self.sample_ref = sampler.store_sample(sample_context)
+            else:
+                self.logs.info(
+                    f"Skipping samples from query '{self.query_name}'. Excluded column(s) present: {offending_columns}."
+                )
+        except BaseException as e:
+            self._cursor_execute_exception_handler(e)
 
     def __append_to_scan(self):
         scan = self.data_source_scan.scan
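The third hunk wraps the sampling block of store() in the same try/except, so an exception raised while executing a sample query is recorded on the query object and reported through the shared handler instead of propagating out of store(). A hedged usage-style sketch of that behavior, reusing the ToyQuery toy from above rather than Soda's real Query, DbSample, or Sampler classes:

# Illustrative only: a cursor whose execute() always fails.
class FailingCursor:
    def execute(self, sql: str) -> None:
        raise RuntimeError("boom")

    def fetchall(self):
        return []


query = ToyQuery("failed_rows_sample", "SELECT * FROM t")
query.store(FailingCursor())

# The failure is captured rather than raised, and rows stays None,
# which is why callers can no longer assume query.rows is a list.
assert query.exception is not None
assert query.rows is None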
2 changes: 1 addition & 1 deletion soda/spark/soda/data_sources/spark_data_source.py
@@ -219,7 +219,7 @@ def get_table_columns(
             ),
         )
         query.execute()
-        if len(query.rows) > 0:
+        if query.rows and len(query.rows) > 0:
             rows = query.rows
             # Remove the partitioning information (see https://spark.apache.org/docs/latest/sql-ref-syntax-aux-describe-table.html)
             partition_indices = [i for i in range(len(rows)) if rows[i][0].startswith("# Partition")]
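This guard accounts for query.rows being None rather than an empty list, which happens when the underlying query fails and the failure is handled by the query object instead of raised (for example, a DESCRIBE-style column query against a missing table): len(None) raises a TypeError, while the truthiness check short-circuits first. A tiny illustrative sketch with made-up variable names:

rows = None  # e.g. what query.rows can look like after a failed, handled execution

# len(rows) > 0 alone would raise: TypeError: object of type 'NoneType' has no len()
if rows and len(rows) > 0:
    print(f"{len(rows)} rows returned")
else:
    print("no rows returned; skipping column parsing")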
