Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 29, 2024
1 parent b780806 commit bdc49b9
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 73 deletions.
82 changes: 40 additions & 42 deletions sdgx/data_models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def check_column_list(cls, value) -> Any:
"""

def get_column_encoder_by_categorical_threshold(
self, num_categories: int
self, num_categories: int
) -> Union[CategoricalEncoderType, None]:
encoder_type = None
if self.categorical_threshold is None:
Expand Down Expand Up @@ -135,16 +135,16 @@ def __eq__(self, other):
if not isinstance(other, Metadata):
return super().__eq__(other)
return (
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
set(self.tag_fields) == set(other.tag_fields)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.tag_fields, other.tag_fields))
)
and all(
self.get(key) == other.get(key)
for key in set(chain(self.format_fields, other.format_fields))
)
and self.version == other.version
)

def query(self, field: str) -> Iterable[str]:
Expand Down Expand Up @@ -204,9 +204,9 @@ def set(self, key: str, value: Any):

old_value = self.get(key)
if (
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
key in self.model_fields
and key not in self.tag_fields
and key not in self.format_fields
):
raise MetadataInitError(
f"Set {key} not in tag_fields, try set it directly as m.{key} = value"
Expand Down Expand Up @@ -302,14 +302,14 @@ def update(self, attributes: dict[str, Any]):

@classmethod
def from_dataloader(
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
dataloader: DataLoader,
max_chunk: int = 10,
primary_keys: Set[str] = None,
include_inspectors: Iterable[str] | None = None,
exclude_inspectors: Iterable[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataLoader and Inspectors
Expand Down Expand Up @@ -369,12 +369,12 @@ def from_dataloader(

@classmethod
def from_dataframe(
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
cls,
df: pd.DataFrame,
include_inspectors: list[str] | None = None,
exclude_inspectors: list[str] | None = None,
inspector_init_kwargs: dict[str, Any] | None = None,
check: bool = False,
) -> "Metadata":
"""Initialize a metadata from DataFrame and Inspectors
Expand Down Expand Up @@ -558,10 +558,10 @@ def get_column_data_type(self, column_name: str):
# find the dtype who has most high inspector level
for each_key in list(self.model_fields.keys()) + list(self._extend.keys()):
if (
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
each_key != "pii_columns"
and each_key.endswith("_columns")
and column_name in self.get(each_key)
and current_level < self.column_inspect_level[each_key]
):
current_level = self.column_inspect_level[each_key]
current_type = each_key
Expand All @@ -582,7 +582,9 @@ def get_column_pii(self, column_name: str):
return True
return False

def change_column_type(self, column_names: str | List[str], column_original_type: str, column_new_type: str):
def change_column_type(
self, column_names: str | List[str], column_original_type: str, column_new_type: str
):
"""Change the type of column."""
if not column_names:
return
Expand Down Expand Up @@ -629,16 +631,12 @@ def do_remove_columns(key, get=True, to_removes=column_names):
res = [item for item in target if item not in to_removes]
elif isinstance(target, dict):
if key == "numeric_format":
obj.set(key, {
k: {
v2 for v2 in v if v2 not in to_removes
} for k, v in target.items()
})
obj.set(
key,
{k: {v2 for v2 in v if v2 not in to_removes} for k, v in target.items()},
)
else:
res = {
k: v for k, v in target.items()
if k not in to_removes
}
res = {k: v for k, v in target.items() if k not in to_removes}
elif isinstance(target, set):
res = target.difference(to_removes)

Expand Down
4 changes: 2 additions & 2 deletions sdgx/data_processors/formatters/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def convert_timestamp_to_datetime(timestamp_column_list, format_dict, processed_
TODO:
if the value <0, the result will be `No Datetime`, try to fix it.
"""

def column_timestamp_formatter(each_stamp: int, timestamp_format: str) -> str:
try:
each_str = datetime.fromtimestamp(each_stamp).strftime(timestamp_format)
Expand All @@ -216,8 +217,7 @@ def column_timestamp_formatter(each_stamp: int, timestamp_format: str) -> str:
if column in result_data.columns:
# Convert the timestamp to datetime format using the format provided in datetime_column_dict
result_data[column] = result_data[column].apply(
column_timestamp_formatter,
timestamp_format=format_dict[column]
column_timestamp_formatter, timestamp_format=format_dict[column]
)
else:
logger.error(f"Column {column} not in processed data's column list!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def _fit_discrete(self, data, encoder_type: CategoricalEncoderType = None):

# Load encoder from metadata
if encoder_type is None and self.metadata:
selected_encoder_type = encoder_type = self.metadata.get_column_encoder_by_name(column_name)
selected_encoder_type = encoder_type = self.metadata.get_column_encoder_by_name(
column_name
)
# if the encoder is not be specified, using onehot.
if encoder_type is None:
encoder_type = "onehot"
Expand Down
7 changes: 5 additions & 2 deletions sdgx/models/ml/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
from sdgx.models.components.optimize.sdv_ctgan.data_sampler import DataSampler
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
BaseSynthesizer as SDVBaseSynthesizer, BatchedSynthesizer,
BaseSynthesizer as SDVBaseSynthesizer,
)
from sdgx.models.components.sdv_ctgan.synthesizers.base import (
BatchedSynthesizer,
random_state,
)
from sdgx.models.components.sdv_ctgan.synthesizers.base import random_state
from sdgx.models.extension import hookimpl
from sdgx.models.ml.single_table.base import MLSynthesizerModel
from sdgx.utils import logger
Expand Down
4 changes: 1 addition & 3 deletions sdgx/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,7 @@ def _sample_once(
batch_size = self.model.get_batch_size()
multiply_factor = 1.2
if isinstance(self.model, CTGANSynthesizerModel):
model_sample_args = {
"drop_more": False
}
model_sample_args = {"drop_more": False}

while missing_count > 0 and max_trails > 0:
sample_data = self.model.sample(
Expand Down
3 changes: 3 additions & 0 deletions tests/data_models/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_metadata(metadata: Metadata):

assert metadata._dump_json()


def test_change_metadata(metadata: Metadata):
metadata = metadata.model_copy()
col = "age"
Expand All @@ -46,13 +47,15 @@ def test_change_metadata(metadata: Metadata):
assert col in metadata.int_columns
assert col not in metadata.datetime_columns


def test_remove_metadata(metadata: Metadata):
    """Check that ``remove_column`` drops a column from ``int_columns``.

    The test operates on the object returned by ``metadata.model_copy()``
    rather than on the fixture instance it receives.
    """
    target_column = "age"
    copied = metadata.model_copy()

    # Precondition: the column is registered as an int column before removal.
    assert target_column in copied.int_columns

    copied.remove_column([target_column])

    # Postcondition: the column is no longer listed among the int columns.
    assert target_column not in copied.int_columns


def test_metadata_save_load(metadata: Metadata, tmp_path: Path):
test_path = tmp_path / "metadata_path_test.json"
metadata.save(test_path)
Expand Down
45 changes: 22 additions & 23 deletions tests/test_ctgan_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@

from sdgx.data_connectors.dataframe_connector import DataFrameConnector
from sdgx.data_models.metadata import Metadata
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import DataTransformer, SpanInfo
from sdgx.models.components.sdv_rdt.transformers.categorical import NormalizedFrequencyEncoder, NormalizedLabelEncoder, \
OneHotEncoder
from sdgx.models.components.optimize.sdv_ctgan.data_transformer import (
DataTransformer,
SpanInfo,
)
from sdgx.models.components.sdv_rdt.transformers.categorical import (
NormalizedFrequencyEncoder,
NormalizedLabelEncoder,
OneHotEncoder,
)
from sdgx.models.components.sdv_rdt.transformers.numerical import ClusterBasedNormalizer
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
Expand All @@ -34,7 +40,7 @@ def demo_single_table_data_pos_neg():
"cat_date": [fake.date() for _ in range(row_cnt)],
"cat_freq": [str(i) for i in range(row_cnt)],
"cat_thres_freq": [str(i) for i in range(100)] * (row_cnt // 100),
"cat_thres_label": [str(i) for i in range(200)] * (row_cnt // 200)
"cat_thres_label": [str(i) for i in range(200)] * (row_cnt // 200),
}
header = X.keys()
yield pd.DataFrame(X, columns=list(header))
Expand All @@ -46,29 +52,21 @@ def demo_single_table_data_pos_neg_metadata(demo_single_table_data_pos_neg):
metadata.categorical_encoder = {
"cat_onehot": "onehot",
"cat_label": "label",
"cat_freq": "frequency"
}
metadata.datetime_format = {
"cat_date": "%Y-%m-%d"
}
metadata.categorical_threshold = {
99: "frequency",
199: "label"
"cat_freq": "frequency",
}
metadata.datetime_format = {"cat_date": "%Y-%m-%d"}
metadata.categorical_threshold = {99: "frequency", 199: "label"}
yield metadata


@pytest.fixture
def demo_single_table_data_pos_neg_connector(demo_single_table_data_pos_neg):
yield DataFrameConnector(
df=demo_single_table_data_pos_neg
)
yield DataFrameConnector(df=demo_single_table_data_pos_neg)


@pytest.fixture
def ctgan_synthesizer(
demo_single_table_data_pos_neg_connector,
demo_single_table_data_pos_neg_metadata
demo_single_table_data_pos_neg_connector, demo_single_table_data_pos_neg_metadata
):
yield Synthesizer(
metadata=demo_single_table_data_pos_neg_metadata,
Expand All @@ -78,8 +76,7 @@ def ctgan_synthesizer(


def test_ctgan_synthesizer_with_pos_neg(
ctgan_synthesizer: Synthesizer,
demo_single_table_data_pos_neg
ctgan_synthesizer: Synthesizer, demo_single_table_data_pos_neg
):
original_data = demo_single_table_data_pos_neg

Expand All @@ -94,7 +91,7 @@ def test_ctgan_synthesizer_with_pos_neg(
for item in transform_list:
span_info: List[SpanInfo] = item.output_info
col_dim = item.output_dimensions
current_data = transformed_data[:, current_dim:current_dim + col_dim]
current_data = transformed_data[:, current_dim : current_dim + col_dim]
current_dim += col_dim
col = item.column_name
if col in ["cat_freq", "cat_thres_freq"]:
Expand All @@ -109,7 +106,9 @@ def test_ctgan_synthesizer_with_pos_neg(
assert col_dim == 1
assert len(span_info) == 1
assert span_info[0].activation_fn == "liner"
assert len(item.transform.categories_to_values.keys()) == original_data[col].nunique(dropna=False)
assert len(item.transform.categories_to_values.keys()) == original_data[col].nunique(
dropna=False
)
assert (current_data >= -1).all() and (current_data <= 1).all()
elif col in ["cat_onehot"]:
assert isinstance(item.transform, OneHotEncoder)
Expand All @@ -136,12 +135,12 @@ def test_ctgan_synthesizer_with_pos_neg(
if is_all_positive:
# Assert that the sampled_data column is also all positive
assert (
sampled_data[column] >= 0
sampled_data[column] >= 0
).all(), f"Column '{column}' in sampled data should be all positive."
elif is_all_negative:
# Assert that the sampled_data column is also all negative
assert (
sampled_data[column] <= 0
sampled_data[column] <= 0
).all(), f"Column '{column}' in sampled data should be all negative."


Expand Down

0 comments on commit bdc49b9

Please sign in to comment.