Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling fixed column relationships by specific_combinations and SpecificCombinationTransformer. #236

Merged
merged 8 commits into from
Nov 18, 2024
1 change: 1 addition & 0 deletions docs/source/user_guides/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ The Guides for users section includes SDG usage for different scenarios.
Use CLI directly <cli>
Use SDG as a library <library>
Synthetic single-table data <single_table>
Synthetic single-table data with specific_combinations <single_table_column_combinations>
Synthetic multi-table data <multi_table>
Evaluation synthetic data <evaluation>
43 changes: 43 additions & 0 deletions docs/source/user_guides/single_table_column_combinations.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
Synthetic single-table data with specific_combinations
======================================================


.. code-block:: python

import pandas as pd

from sdgx.data_connectors.csv_connector import CsvConnector
from sdgx.data_models.metadata import Metadata
from sdgx.models.ml.single_table.ctgan import CTGANSynthesizerModel
from sdgx.synthesizer import Synthesizer
from sdgx.utils import download_demo_data

dataset_csv = download_demo_data()
data_connector = CsvConnector(path=dataset_csv)

# Specify the fixed column combinations.
# Multiple combinations can be specified as separate tuples; here we specify only one.
metadata = Metadata.from_dataframe(pd.read_csv(dataset_csv))
combinations = {("education", "educational-num")}
metadata.update({"specific_combinations": combinations})

synthesizer = Synthesizer(
model=CTGANSynthesizerModel(epochs=1), # For quick demo
data_connector=data_connector,
metadata=metadata
)
synthesizer.fit()
sampled_data = synthesizer.sample(1000)
synthesizer.cleanup() # Clean all cache


from sdgx.metrics.column.jsd import JSD

JSD = JSD()


selected_columns = ["workclass"]
isDiscrete = True
metrics = JSD.calculate(data_connector.read(), sampled_data, selected_columns, isDiscrete)

print("JSD metric of column %s: %g" % (selected_columns[0], metrics))
1 change: 1 addition & 0 deletions sdgx/data_processors/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DataProcessorManager(Manager):
preset_defalut_processors = [
p.lower()
for p in [
"SpecificCombinationTransformer",
"FixedCombinationTransformer",
"NonValueTransformer",
"OutlierTransformer",
Expand Down
111 changes: 89 additions & 22 deletions sdgx/data_processors/transformers/fixed_combination.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,89 @@ class FixedCombinationTransformer(Transformer):
"""
A transformer that handles columns with fixed combinations in a DataFrame.

This transformer identifies and processes columns that have fixed relationships (high covariance) in a given DataFrame.
It can remove these columns during the conversion process and restore them during the reverse conversion process.
This transformer aims to automatically identify and process columns that have fixed relationships (high covariance) in
a given DataFrame.

The relationships between columns include:
- Numerical function relationships: assess them based on covariance between the columns.
- Categorical mapping relationships: check for duplicate values for each column.

Note that we support one-to-one mappings between columns now, and each corresponding relationship will not
include duplicate columns.

For example:
Suppose we detect:
1 numerical relationship: (Key1, Value1, Value2)
2 one-to-one relationships: (Key1, Key2), (Category1, Category2)

| Key1 | Key2 | Category1 | Category2 | Value1 | Value2 |
| :--: | :--: | :-------: | :-------: | :----: | :----: |
| 1 | A | 1001 | Apple | 10 | 20 |
| 2 | B | 1002 | Broccoli | 15 | 30 |
| 2 | B | 1001 | Apple | 20 | 20 |
"""

Attributes:
fixed_combinations (dict[str, set[str]]): A dictionary mapping column names to sets of column names that have fixed relationships with them.
fixed_combinations: dict[str, set[str]]
"""
A dictionary mapping column names to sets of column names that have fixed relationships with them.
"""

def __init__(self):
super().__init__() # Call the parent class's initialization method
# Initialize the variable in the initialization method
simplified_fixed_combinations: dict[str, set[str]]
"""
A dictionary mapping column names to sets of column names that have fixed relationships with them.
"""

self.fixed_combinations: dict[str, set[str]] = {}
"""
A dictionary mapping column names to sets of column names that have fixed relationships with them.
"""
column_mappings: dict[(str, str), dict[str, str]]
"""
A dictionary mapping tuples of column names to dictionaries of value mappings.
"""

self.simplified_fixed_combinations: dict[str, set[str]] = {}
"""
A dictionary mapping column names to sets of column names that have fixed relationships with them.
"""
is_been_specified: bool
"""
A boolean flag indicating whether the user has specified combinations.
If True, this auto-detect transform does not need to run.
"""

def __init__(self):
super().__init__()
self.fixed_combinations: dict[str, set[str]] = {}
self.simplified_fixed_combinations: dict[str, set[str]] = {}
self.column_mappings: dict[(str, str), dict[str, str]] = {}
self.is_been_specified = False

@property
def is_exist_fixed_combinations(self) -> bool:
"""
A dictionary mapping tuples of column names to dictionaries of value mappings.
A boolean flag indicating whether the inspector has detected any fixed combinations.
If False, this auto-detect transform does not need to run.
"""
return bool(self.fixed_combinations)

def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
"""Fit the transformer and save the relationships between columns.

Args:
metadata (Metadata): Metadata object
"""
# Check if exist specific combinations by user. If True, needn't run this auto-detect transform.
if metadata.get("specific_combinations"):
logger.info(
"Fit data using FixedCombinationTransformer(been specified)... Finished (No action)."
)
self.is_been_specified = True
self.fitted = True
return

# Check if exist fixed combinations, if not exist, needn't run this auto-detect transform.
self.fixed_combinations = metadata.get("fixed_combinations")
jalr4ever marked this conversation as resolved.
Show resolved Hide resolved
if not self.is_exist_fixed_combinations:
logger.info(
"Fit data using FixedCombinationTransformer(not existed)... Finished (No action)."
)
self.fitted = True
return

# simplify the fixed_combinations, remove the symmetric and duplicate combinations
# Simplify the fixed_combinations, remove the symmetric and duplicate combinations
simplified_fixed_combinations = {}
seen = set()

Expand All @@ -61,7 +109,7 @@ def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
)

for base_col, related_cols in self.fixed_combinations.items():
# create a immutable set of base_col and related_cols
# create an immutable set of base_col and related_cols
combination = frozenset([base_col]) | frozenset(related_cols)

# if the combination has not been seen, add it to the simplified_fixed_combinations
Expand All @@ -70,9 +118,7 @@ def fit(self, metadata: Metadata | None = None, **kwargs: dict[str, Any]):
seen.add(combination)

self.simplified_fixed_combinations = simplified_fixed_combinations

self.has_column_mappings = False

self.fitted = True

def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
Expand All @@ -83,10 +129,8 @@ def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
to optimize performance.

NOTE:
TODO-Enhance-Designed configuration interface. Use the user's configured Constrain if provided.
TODO-Enhance-Refactor Inspector by chain-of-responsibility, base one-to-one on Identified discrete_columns.
The current implementation has space for optimization:
- If column_mappings already exist, no recalculation is performed
- The column_mappings definition depends on the first batch of data from the DataLoader
- This might miss some edge cases where column relationships are very comprehensive
(e.g., some column correspondences might only appear in later batches)
Expand All @@ -102,6 +146,18 @@ def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
pd.DataFrame: The processed DataFrame (unchanged in this implementation)
"""

if self.is_been_specified:
logger.info(
"Converting data using FixedCombinationTransformer(been specified)... Finished (No action)."
)
return raw_data

if not self.is_exist_fixed_combinations:
logger.info(
"Converting data using FixedCombinationTransformer(not existed)... Finished (No action)."
)
return raw_data

if self.has_column_mappings:
logger.info(
"Converting data using FixedCombinationTransformer... Finished (No action)."
Expand Down Expand Up @@ -157,6 +213,17 @@ def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
Returns:
pd.DataFrame: The DataFrame with original values restored based on the defined mappings.
"""
if self.is_been_specified:
logger.info(
"Reverse converting data using FixedCombinationTransformer(been specified)... Finished (No action)."
)
return processed_data

if not self.is_exist_fixed_combinations:
logger.info(
"Reverse converting data using FixedCombinationTransformer(not existed)... Finished (No action)."
)
return processed_data

result_df = processed_data.copy()

Expand Down
139 changes: 139 additions & 0 deletions sdgx/data_processors/transformers/specific_combination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from typing import Dict, List, Set

import numpy as np
import pandas as pd

from sdgx.data_loader import DataLoader
from sdgx.data_models.metadata import Metadata
from sdgx.data_processors.extension import hookimpl
from sdgx.data_processors.transformers.base import Transformer
from sdgx.utils import logger


class SpecificCombinationTransformer(Transformer):
    """
    A transformer used to handle specific combinations of columns in tabular data.

    The relationships between columns can be quite complex, and the existing
    `FixedCombinationTransformer` is not capable of comprehensive automatic detection.
    This transformer lets users manually specify which columns belong together. Users can
    define multiple groups, with each group supporting multiple columns.

    During `fit()`, the transformer records every distinct combination of values observed
    for each group. During `reverse_convert()`, the columns of every group are overwritten,
    row by row, with a combination drawn uniformly at random from the recorded ones, so the
    output only ever contains combinations that were seen while fitting.

    For example:

    | Category A | Category B | Category C | Category D | Category E |
    | :--------: | :--------: | :--------: | :--------: | :--------: |
    |     A1     |     B1     |     C1     |     D1     |     E1     |
    |     A1     |     B1     |     C2     |     D2     |     E2     |
    |     A2     |     B2     |     C1     |     D1     |     E3     |

    Here the user can specify combinations like (Category A, Category B) and
    (Category C, Category D, Category E).

    For now, the `specific_combinations` are passed in through `Metadata`.
    """

    column_groups: List[Set[str]]
    """
    A list where each element is a set of column names forming one user-specified group.
    """

    mappings: Dict[frozenset, pd.DataFrame]
    """
    Maps each column group (as a frozenset) to a DataFrame holding its distinct observed value combinations.
    """

    specified: bool
    """
    Whether the user specified any combination. Only when True are `specific_combinations` handled.
    """

    def __init__(self):
        # Let the base class set up its own state, as FixedCombinationTransformer does.
        super().__init__()
        self.column_groups: List[Set[str]] = []
        self.mappings: Dict[frozenset, pd.DataFrame] = {}
        self.specified = False

    def fit(self, metadata: Metadata | None = None, tabular_data: DataLoader | pd.DataFrame = None):
        """
        Study the combination relationships and value mapping of columns.

        Args:
            metadata: Metadata containing information about specific column combinations.
            tabular_data: The tabular data to be fitted, can be a DataLoader object or a pandas DataFrame.

        Raises:
            ValueError: If combinations are specified but no tabular data is provided.
        """
        # `metadata` defaults to None, so guard the lookup instead of failing on `.get`.
        specific_combinations = (
            metadata.get("specific_combinations") if metadata is not None else None
        )

        # Start from a clean state so that re-fitting never keeps stale groups or mappings.
        self.column_groups = []
        self.mappings = {}
        self.specified = False

        if not specific_combinations:
            logger.info(
                "Fit data using SpecificCombinationTransformer(No specified)... Finished (No action)."
            )
            self.fitted = True
            return

        if tabular_data is None:
            raise ValueError(
                "SpecificCombinationTransformer requires `tabular_data` to fit when "
                "`specific_combinations` are specified."
            )

        # Record the distinct value combinations observed for each group of columns.
        column_groups = [set(cols) for cols in specific_combinations]
        mappings = {
            frozenset(group): tabular_data[list(group)].drop_duplicates()
            for group in column_groups
        }

        self.column_groups = column_groups
        self.mappings = mappings
        self.fitted = True
        self.specified = True
        logger.info("SpecificCombinationTransformer Fitted.")

    def convert(self, raw_data: pd.DataFrame) -> pd.DataFrame:
        """
        Convert the raw data based on the learned mapping relationships.

        No column is altered here; the data is simply handed to the parent `convert`.

        Args:
            raw_data: The raw data to be converted.

        Returns:
            The converted data.
        """
        if self.specified:
            logger.info("SpecificCombinationTransformer convert doing nothing...")
        else:
            logger.info(
                "Converting data using SpecificCombinationTransformer(No specified)... Finished (No action)."
            )
        return super().convert(raw_data)

    def reverse_convert(self, processed_data: pd.DataFrame) -> pd.DataFrame:
        """
        Reverse convert the processed data to ensure it conforms to the original format.

        For every specified group, each row receives a combination sampled uniformly
        (with replacement) from the combinations recorded in `fit()`.

        Args:
            processed_data: The processed data to be reverse converted.

        Returns:
            The reverse converted data. A copy is returned when combinations were specified;
            otherwise the input is returned unchanged.
        """
        if not self.specified:
            logger.info(
                "Reverse converting data using SpecificCombinationTransformer(No specified)... Finished (No action)."
            )
            return processed_data

        result_df = processed_data.copy()
        n_rows = len(result_df)

        for group in self.column_groups:
            group_mapping = self.mappings[frozenset(group)]
            if len(group_mapping) == 0:
                # Nothing was recorded for this group (e.g. fitted on empty data);
                # sampling from an empty population would raise, so leave the columns as-is.
                logger.warning(
                    "SpecificCombinationTransformer has no recorded combinations for a group... Skipped."
                )
                continue

            random_indices = np.random.choice(len(group_mapping), size=n_rows)
            random_mappings = group_mapping.iloc[random_indices]
            # Assign column by column: a single 2-D assignment of mixed-type values
            # would coerce every column in the group to object dtype.
            for col in group:
                result_df[col] = random_mappings[col].to_numpy()

        return result_df


@hookimpl
def register(manager):
    """
    Plugin hook that makes `SpecificCombinationTransformer` available to the manager.

    Args:
        manager: The manager object on which the transformer class is registered by name.
    """
    transformer_cls = SpecificCombinationTransformer
    manager.register("SpecificCombinationTransformer", transformer_cls)
Loading
Loading