From 095e73b131a50a5dec5ffd948be75a836a0638cc Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 16 Aug 2023 09:59:49 -0700 Subject: [PATCH 1/3] Warning fix --- ctgan/data_transformer.py | 2 +- tests/unit/test_data_transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index 3b6a8822..a44ac5a2 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -46,7 +46,7 @@ def _fit_continuous(self, data): A ``ColumnTransformInfo`` object. """ column_name = data.columns[0] - gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=min(len(data), 10)) + gm = ClusterBasedNormalizer(max_clusters=min(len(data), 10)) gm.fit(data, column_name) num_components = sum(gm.valid_component_indicator) diff --git a/tests/unit/test_data_transformer.py b/tests/unit/test_data_transformer.py index 19fa205c..d4a8e550 100644 --- a/tests/unit/test_data_transformer.py +++ b/tests/unit/test_data_transformer.py @@ -75,7 +75,7 @@ def test__fit_continuous_max_clusters(self, MockCBN): transformer._fit_continuous(data) # Assert - MockCBN.assert_called_once_with(model_missing_values=True, max_clusters=len(data)) + MockCBN.assert_called_once_with(max_clusters=len(data)) @patch('ctgan.data_transformer.OneHotEncoder') def test___fit_discrete(self, MockOHE): From 73d223ece5e2432d5585ac9433ce819869ebaf11 Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 17 Aug 2023 08:38:26 -0700 Subject: [PATCH 2/3] Feedback --- ctgan/data_transformer.py | 3 ++- tests/unit/test_data_transformer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index a44ac5a2..8fa4c721 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -46,7 +46,8 @@ def _fit_continuous(self, data): A ``ColumnTransformInfo`` object. """ column_name = data.columns[0] - gm = ClusterBasedNormalizer(max_clusters=min(len(data), 10)) + gm = ClusterBasedNormalizer( + missing_value_generation='from_column', max_clusters=min(len(data), 10)) gm.fit(data, column_name) num_components = sum(gm.valid_component_indicator) diff --git a/tests/unit/test_data_transformer.py b/tests/unit/test_data_transformer.py index d4a8e550..8559fc00 100644 --- a/tests/unit/test_data_transformer.py +++ b/tests/unit/test_data_transformer.py @@ -75,7 +75,8 @@ def test__fit_continuous_max_clusters(self, MockCBN): transformer._fit_continuous(data) # Assert - MockCBN.assert_called_once_with(max_clusters=len(data)) + MockCBN.assert_called_once_with( + missing_value_generation='from_column', max_clusters=len(data)) @patch('ctgan.data_transformer.OneHotEncoder') def test___fit_discrete(self, MockOHE): From 5c37650fdd1ef27f61cabe5a2db52ac77963d70e Mon Sep 17 00:00:00 2001 From: Felipe Date: Thu, 17 Aug 2023 09:51:49 -0700 Subject: [PATCH 3/3] Up rdt version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e1083ca4..e6b4e992 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ "torch>=1.8.0;python_version<'3.10'", "torch>=1.11.0;python_version>='3.10' and python_version<'3.11'", "torch>=2.0.0;python_version>='3.11'", - 'rdt>=1.3.0,<2.0', + 'rdt>=1.6.1,<2.0', ] setup_requires = [