Skip to content

Commit

Permalink
Sampling with HMA Synthesizer generates many SingleTableMetadata de…
Browse files Browse the repository at this point in the history
…precation warnings (#2332)
  • Loading branch information
R-Palazzo authored Jan 8, 2025
1 parent 974da0e commit 0ac8308
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Hierarchical Modeling Algorithms."""

import logging
import warnings
from collections import defaultdict
from copy import deepcopy

Expand Down Expand Up @@ -552,7 +553,12 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})
table_meta = self.metadata.get_table_metadata(child_name)
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message=".*The 'SingleTableMetadata' is deprecated.*"
)
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])

synthesizer._set_parameters(parameters, default_parameters)
else:
synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}']
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sdv.datasets.local import load_csvs
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot
from sdv.metadata import MultiTableMetadata
from sdv.metadata.metadata import Metadata
from sdv.multi_table import HMASynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint
Expand Down Expand Up @@ -2637,3 +2638,20 @@ def test_column_order():
assert table_1_column != list(data['table_1'].columns)
assert table_1_column == ['col_1', 'col_2', 'col_3']
assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C']


def test_no_deprecation_warning_single_table_metadata_sampling():
"""Test that no single-table metadata deprecation warning raises with `MultiTableMetadata`."""
# Setup
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
multi_metadata = MultiTableMetadata()
multi_metadata.detect_from_dataframes(data)
synthesizer = HMASynthesizer(multi_metadata)
synthesizer.fit(data)

# Run
with warnings.catch_warnings(record=True) as captured_warnings:
synthesizer.sample()

# Assert
assert len(captured_warnings) == 0

0 comments on commit 0ac8308

Please sign in to comment.