diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 6989f691f..f93967adc 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -1,6 +1,7 @@ """Hierarchical Modeling Algorithms.""" import logging +import warnings from collections import defaultdict from copy import deepcopy @@ -552,7 +553,12 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): parameters = self._extract_parameters(parent_row, child_name, foreign_key) default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.get_table_metadata(child_name) - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', message=".*The 'SingleTableMetadata' is deprecated.*" + ) + synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer._set_parameters(parameters, default_parameters) else: synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}'] diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 9e7058116..2539ff98e 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -19,6 +19,7 @@ from sdv.datasets.local import load_csvs from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot +from sdv.metadata import MultiTableMetadata from sdv.metadata.metadata import Metadata from sdv.multi_table import HMASynthesizer from tests.integration.single_table.custom_constraints import MyConstraint @@ -2637,3 +2638,20 @@ def test_column_order(): assert table_1_column != list(data['table_1'].columns) assert table_1_column == ['col_1', 'col_2', 'col_3'] assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C'] + + +def test_no_deprecation_warning_single_table_metadata_sampling(): + """Test that no single-table metadata deprecation warning raises with `MultiTableMetadata`.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + multi_metadata = MultiTableMetadata() + multi_metadata.detect_from_dataframes(data) + synthesizer = HMASynthesizer(multi_metadata) + synthesizer.fit(data) + + # Run + with warnings.catch_warnings(record=True) as captured_warnings: + synthesizer.sample() + + # Assert + assert len(captured_warnings) == 0