Skip to content

Commit

Permalink
Include stack trace when sampling errors are surfaced (#2329)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored Jan 2, 2025
1 parent 0d203be commit 3e676fa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def handle_sampling_error(output_file_path, sampling_error):
)

if error_msg:
raise type(sampling_error)(error_msg + '\n' + str(sampling_error))
raise type(sampling_error)(error_msg) from sampling_error

raise sampling_error

Expand Down
6 changes: 4 additions & 2 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,15 +1546,17 @@ def test__sample_with_progress_bar_without_output_filepath(self):
instance._fitted = True
expected_message = re.escape(
'Error: Sampling terminated. No results were saved due to unspecified '
'"output_file_path".\nMocked Error'
'"output_file_path".'
)
instance._sample_in_batches.side_effect = RuntimeError('Mocked Error')

# Run and Assert
with pytest.raises(RuntimeError, match=expected_message):
with pytest.raises(RuntimeError, match=expected_message) as exception:
BaseSingleTableSynthesizer._sample_with_progress_bar(
instance, output_file_path=None, num_rows=10
)
assert isinstance(exception.value.__cause__, RuntimeError)
assert 'Mocked Error' in str(exception.value.__cause__)

@patch('sdv.single_table.base.datetime')
def test_sample(self, mock_datetime, caplog):
Expand Down
12 changes: 9 additions & 3 deletions tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,13 @@ def test_unflatten_dict():
def test_handle_sampling_error_temp_file():
"""Test that an error is raised when temp dir is ``False``."""
# Run and Assert
error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.\nTest error'
with pytest.raises(ValueError, match=error_msg):
error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.'
with pytest.raises(ValueError, match=error_msg) as exception:
handle_sampling_error('test.csv', ValueError('Test error'))

assert isinstance(exception.value.__cause__, ValueError)
assert 'Test error' in str(exception.value.__cause__)


def test_handle_sampling_error_false_temp_file_none_output_file():
"""Test the ``handle_sampling_error`` function.
Expand All @@ -228,9 +231,12 @@ def test_handle_sampling_error_false_temp_file_none_output_file():
"""
# Run and Assert
error_msg = 'Test error'
with pytest.raises(ValueError, match=error_msg):
with pytest.raises(ValueError) as exception:
handle_sampling_error('test.csv', ValueError('Test error'))

assert isinstance(exception.value.__cause__, ValueError)
assert error_msg in str(exception.value.__cause__)


def test_handle_sampling_error_ignore():
"""Test that the error is raised if the error is the no rows error."""
Expand Down

0 comments on commit 3e676fa

Please sign in to comment.