Reduce the memory usage of Gaussian Copula training. #233

Merged · 1 commit · Nov 7, 2024
112 changes: 112 additions & 0 deletions sdgx/models/components/optimize/sdv_copulas/data_transformer.py
@@ -0,0 +1,112 @@
import numpy as np
import pandas as pd

from sdgx.models.components.sdv_ctgan.data_transformer import (
    ColumnTransformInfo,
    DataTransformer,
    SpanInfo,
)
from sdgx.models.components.sdv_rdt.transformers import ClusterBasedNormalizer
from sdgx.models.components.sdv_rdt.transformers.categorical import FrequencyEncoder

# TODO(Enhance): use different types of encoders for different kinds of discrete
# columns, e.g. ordered columns or high-cardinality columns.


class StatisticDataTransformer(DataTransformer):
    """Data Transformer for statistical models such as the Gaussian Copula."""

    def _fit_continuous(self, data):
        """Train a ClusterBasedNormalizer for a continuous column.

        With max_clusters=1 the normalizer keeps a single Gaussian component,
        so the column is encoded as one normalized value (the base class
        additionally emits a one-hot component vector).
        """
        column_name = data.columns[0]
        gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=1)
        gm.fit(data, column_name)

        return ColumnTransformInfo(
            column_name=column_name,
            column_type="continuous",
            transform=gm,
            output_info=[SpanInfo(1, "tanh")],
            output_dimensions=1,
        )

    def _transform_continuous(self, column_transform_info, data):
        """Transform a continuous column to its normalized representation."""
        gm = column_transform_info.transform
        transformed = gm.transform(data)
        return transformed[f"{data.columns[0]}.normalized"].to_numpy().reshape(-1, 1)

    def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
        """Inverse transform a continuous column."""
        gm = column_transform_info.transform
        column_name = column_transform_info.column_name

        # Rebuild the frame the normalizer expects: the normalized values plus
        # a virtual component column (there is only one component, see
        # _fit_continuous).
        data = pd.DataFrame(
            {
                f"{column_name}.normalized": column_data.flatten(),
                f"{column_name}.component": [0] * len(column_data),  # virtual component
            }
        )

        if sigmas is not None:
            data[f"{column_name}.normalized"] = np.random.normal(
                data[f"{column_name}.normalized"], sigmas[st]
            )

        result = gm.reverse_transform(data)

        # Prefer the named column; otherwise fall back to the first column.
        if column_name in result.columns:
            return result[column_name]
        return result.iloc[:, 0]

    def _fit_discrete(self, data):
        """Fit a FrequencyEncoder for a discrete column."""
        column_name = data.columns[0]
        freq_encoder = FrequencyEncoder()
        freq_encoder.fit(data, column_name)

        # Save the original unique values for the inverse transform.
        if not hasattr(self, "_discrete_values"):
            self._discrete_values = {}
        self._discrete_values[column_name] = data[column_name].unique().tolist()

        return ColumnTransformInfo(
            column_name=column_name,
            column_type="discrete",
            transform=freq_encoder,
            output_info=[SpanInfo(1, "tanh")],
            output_dimensions=1,
        )

    def _transform_discrete(self, column_transform_info, data):
        """Transform a discrete column using frequency encoding."""
        freq_encoder = column_transform_info.transform
        return freq_encoder.transform(data).to_numpy().reshape(-1, 1)

    def _inverse_transform_discrete(self, column_transform_info, column_data):
        """Inverse transform a discrete column from its frequency encoding."""
        freq_encoder = column_transform_info.transform
        column_name = column_transform_info.column_name

        data = pd.DataFrame({column_name: column_data.flatten()})

        # The encoder assigns each category an interval of [0, 1); its `starts`
        # frame is indexed by the interval start points, with the owning
        # category as a column.
        categories = freq_encoder.starts["category"].values
        starts = freq_encoder.starts.index.values

        # Map each encoded value to the category whose start point is closest.
        result = []
        for val in data[column_name]:
            idx = np.abs(starts - val).argmin()
            result.append(categories[idx])

        return pd.Series(result, index=data.index, dtype=freq_encoder.dtype)
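
For orientation, here is a minimal usage sketch of the new transformer (not part of the PR; the toy dataframe is invented, and the `fit(data, discrete_columns)` signature is assumed from the CTGAN-style `DataTransformer` base class):

```python
import pandas as pd

from sdgx.models.components.optimize.sdv_copulas.data_transformer import (
    StatisticDataTransformer,
)

# Invented toy data: one continuous and one discrete column.
df = pd.DataFrame(
    {
        "age": [23, 35, 41, 29, 52],
        "city": ["NY", "SF", "NY", "LA", "SF"],
    }
)

transformer = StatisticDataTransformer()
transformer.fit(df, discrete_columns=["city"])

encoded = transformer.transform(df)  # one output column per input column
decoded = transformer.inverse_transform(encoded)  # round-trip back to a frame

print(encoded.shape)  # (5, 2): "age" normalized, "city" frequency-encoded
```

Note that both column types collapse to a single output column each, which is what keeps the transformed matrix narrow.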
5 changes: 4 additions & 1 deletion sdgx/models/statistics/single_table/copula.py
@@ -10,6 +10,9 @@
 from sdgx.data_loader import DataLoader
 from sdgx.data_models.metadata import Metadata
 from sdgx.exceptions import NonParametricError, SynthesizerInitError
+from sdgx.models.components.optimize.sdv_copulas.data_transformer import (
+    StatisticDataTransformer,
+)
 from sdgx.models.components.sdv_copulas import multivariate
 from sdgx.models.components.sdv_ctgan.data_transformer import DataTransformer
 from sdgx.models.components.sdv_rdt.transformers import OneHotEncoder
@@ -138,7 +141,7 @@ def fit(self, metadata: Metadata, dataloader: DataLoader, *args, **kwargs):
         self.metadata = metadata

         # load the original transformer
-        self._transformer = DataTransformer()
+        self._transformer = StatisticDataTransformer()

         # self._transformer.fit(processed_data, self.metadata[0])
         self._transformer.fit(processed_data, self.discrete_cols)
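
The memory saving comes from the width of the transformed data. The original CTGAN-style `DataTransformer` one-hot encodes each discrete column (one output per category) and emits a component vector per continuous column, while `StatisticDataTransformer` emits exactly one output column per input column. A back-of-the-envelope sketch with invented numbers, assuming the Gaussian Copula fits a correlation matrix over the transformed columns:

```python
# 10 discrete columns with 1000 categories each.
n_discrete, cardinality = 10, 1000

one_hot_width = n_discrete * cardinality  # 10_000 transformed columns
frequency_width = n_discrete              # 10 transformed columns

# A float64 correlation matrix is quadratic in that width:
print(one_hot_width**2 * 8 / 1e6, "MB")  # 800.0 MB
print(frequency_width**2 * 8, "bytes")   # 800 bytes
```

The trade-off is fidelity: frequency encoding maps categories onto a single numeric axis, which is the kind of limitation the TODO at the top of the new file hints at.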