diff --git a/sdgx/models/statistics/single_table/base.py b/sdgx/models/statistics/single_table/base.py index e4b630e0..ce8c1291 100644 --- a/sdgx/models/statistics/single_table/base.py +++ b/sdgx/models/statistics/single_table/base.py @@ -2,24 +2,28 @@ # Which is Lincensed by MIT License import os -from copy import deepcopy -from typing import List, Optional import numpy as np import torch +from sdgx.data_loader import DataLoader +from sdgx.data_models.metadata import Metadata +from sdgx.models.base import SynthesizerModel -class StatisticSynthesizerModel: + +class StatisticSynthesizerModel(SynthesizerModel): random_states = None - def __init__(self, transformer=None, sampler=None) -> None: + def __init__(self, transformer=None, sampler=None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._generator = None self.model = None self.status = "UNFINED" self.model_type = "MODEL_TYPE_UNDEFINED" # self.epochs = epochs self._device = "CPU" - def fit(self, input_df, discrete_cols: Optional[List] = None): + def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs): raise NotImplementedError def set_device(self, device): diff --git a/sdgx/models/statistics/single_table/copula.py b/sdgx/models/statistics/single_table/copula.py index bec24a31..309a657c 100644 --- a/sdgx/models/statistics/single_table/copula.py +++ b/sdgx/models/statistics/single_table/copula.py @@ -9,16 +9,13 @@ import sdgx.models.components.sdv_copulas as copulas from sdgx.data_loader import DataLoader from sdgx.data_models.metadata import Metadata -from sdgx.exceptions import NonParametricError, SynthesizerInitError +from sdgx.exceptions import NonParametricError from sdgx.models.components.optimize.sdv_copulas.data_transformer import ( StatisticDataTransformer, ) from sdgx.models.components.sdv_copulas import multivariate -from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer -from sdgx.models.components.sdv_rdt.transformers import OneHotEncoder from sdgx.models.components.utils import ( flatten_dict, - log_numerical_distributions_error, unflatten_dict, validate_numerical_distributions, ) @@ -27,7 +24,7 @@ LOGGER = logging.getLogger(__name__) -class GaussianCopulaSynthesizer(StatisticSynthesizerModel): +class GaussianCopulaSynthesizerModel(StatisticSynthesizerModel): """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. Args: diff --git a/tests/models/test_copula.py b/tests/models/test_copula.py index 7433b27d..ed30e8e6 100644 --- a/tests/models/test_copula.py +++ b/tests/models/test_copula.py @@ -1,10 +1,7 @@ -from pathlib import Path - import pandas as pd import pytest -from sdgx.models.statistics.single_table.copula import GaussianCopulaSynthesizer -from sdgx.utils import get_demo_single_table +from sdgx.models.statistics.single_table.copula import GaussianCopulaSynthesizerModel @pytest.fixture @@ -13,7 +10,7 @@ def dummy_data(dummy_single_table_path): def test_gaussian_copula(dummy_single_table_metadata, dummy_single_table_data_loader): - model = GaussianCopulaSynthesizer() + model = GaussianCopulaSynthesizerModel() model.fit(dummy_single_table_metadata, dummy_single_table_data_loader) sampled_data = model.sample(10)